Skip to content

Commit

Permalink
Prevent oversampling of study submissions by any single worker
Browse files Browse the repository at this point in the history
  • Loading branch information
meta-paul committed Oct 19, 2023
1 parent e7230c3 commit 3345689
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 6 deletions.
30 changes: 29 additions & 1 deletion mephisto/abstractions/providers/prolific/prolific_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,41 @@ def new_from_provider_data(

assert isinstance(unit, ProlificUnit), "Can only register Prolific agents to Prolific units"

agent = cls.new(db, worker, unit)
unit.worker_id = worker.db_id
agent._unit = unit
task_run: "TaskRun" = agent.get_task_run()

# In case provider API wasn't responsive, we ensure this submission
# doesn't exceed per-worker cap for this Task. Othewrwise don't process submission.
if not worker.can_send_more_submissions_for_task(task_run):
logger.info(
f"Submission from worker \"{worker.db_id}\" is over the Task's submission cap."
)
try:
worker.exclude_worker_from_task(task_run)
except Exception:
logger.exception(
f"Failed to exclude worker {worker.db_id} in TaskRun {task_run.db_id}."
)
return agent

prolific_study_id = provider_data["prolific_study_id"]
prolific_submission_id = provider_data["assignment_id"]
unit.register_from_provider_data(prolific_study_id, prolific_submission_id)

logger.debug("Prolific Submission has been registered successfully")

return super().new_from_provider_data(db, worker, unit, provider_data)
# Check whether we need to prevent this worker from future submissions in this Task
if not worker.can_send_more_submissions_for_task(task_run):
try:
worker.exclude_worker_from_task(task_run)
except Exception:
logger.exception(
f"Failed to exclude worker {worker.db_id} in TaskRun {task_run.db_id}."
)

return agent

def approve_work(
self,
Expand Down
10 changes: 8 additions & 2 deletions mephisto/abstractions/providers/prolific/prolific_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,18 @@ def _get_qualified_workers(
self,
qualifications: List[QualificationType],
bloked_participant_ids: List[str],
task_run: "TaskRun",
) -> List["Worker"]:
qualified_workers = []
workers: List[Worker] = self.db.find_workers(provider_type="prolific")
# `worker_name` is Prolific Participant ID in provider-specific datastore
available_workers = [w for w in workers if w.worker_name not in bloked_participant_ids]

for worker in available_workers:
if worker_is_qualified(worker, qualifications):
if (
worker.can_send_more_submissions_for_task(task_run) and
worker_is_qualified(worker, qualifications)
):
qualified_workers.append(worker)

return qualified_workers
Expand Down Expand Up @@ -305,7 +309,9 @@ def setup_resources_for_task_run(
prolific_specific_qualifications = new_prolific_specific_qualifications

if qualifications:
qualified_workers = self._get_qualified_workers(qualifications, blocked_participant_ids)
qualified_workers = self._get_qualified_workers(
qualifications, blocked_participant_ids, task_run,
)

if qualified_workers:
prolific_workers_ids = [w.worker_name for w in qualified_workers]
Expand Down
13 changes: 12 additions & 1 deletion mephisto/abstractions/providers/prolific/prolific_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,10 @@ def remove_worker_qualification(
*args,
**kwargs,
) -> None:
"""Remove a qualification for the given worker (remove a worker from a Participant Group)"""
"""
Remove a qualification for the given worker (remove a worker from a Participant Group).
NOTE: If a participant is not a member of the group, they will be ignored (from API Docs)
"""
try:
client.ParticipantGroups.remove_participants_from_group(
id=qualification_id,
Expand All @@ -591,6 +594,14 @@ def remove_worker_qualification(
raise


def exclude_worker_from_participant_group(
client: ProlificClient,
worker_id: str,
participant_group_id: str,
):
remove_worker_qualification(client, worker_id, participant_group_id)


def pay_bonus(
client: ProlificClient,
task_run_config: "DictConfig",
Expand Down
42 changes: 40 additions & 2 deletions mephisto/abstractions/providers/prolific/prolific_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
from typing import Tuple
from typing import TYPE_CHECKING

from omegaconf import DictConfig

from mephisto.abstractions.providers.prolific import prolific_utils
from mephisto.abstractions.providers.prolific.api.client import ProlificClient
from mephisto.abstractions.providers.prolific.provider_type import PROVIDER_TYPE
Expand All @@ -28,6 +26,7 @@
from mephisto.abstractions.providers.prolific.prolific_requester import ProlificRequester
from mephisto.abstractions.providers.prolific.prolific_unit import ProlificUnit
from mephisto.data_model.requester import Requester
from mephisto.data_model.task import Task
from mephisto.data_model.task_run import TaskRun
from mephisto.data_model.unit import Unit

Expand Down Expand Up @@ -181,6 +180,45 @@ def unblock_worker(self, reason: str, requester: "Requester") -> Tuple[bool, str

return True, ""

def exclude_worker_from_task(
self, task_run: Optional["TaskRun"] = None,
) -> Tuple[bool, str]:
"""Exclude this worker from current Task"""
logger.debug(f"{self.log_prefix}Excluding worker {self.worker_name} from Prolific")

# 1. Get Client
requester: "ProlificRequester" = task_run.get_requester()
client = self._get_client(requester.requester_name)

# 2. Find TaskRun IDs that are related to current Task
task: "Task" = task_run.get_task()
all_task_run_ids_for_task: List[str] = [t.db_id for t in task.get_runs()]

# 3. Select all Participant Group IDs that are related to the Task
datastore_qualifications = self.datastore.find_qualifications_by_ids(
task_run_ids=all_task_run_ids_for_task,
)
prolific_participant_group_ids = [
q["prolific_participant_group_id"] for q in datastore_qualifications
]

logger.debug(
f"{self.log_prefix}Found {len(prolific_participant_group_ids)} Participant Groups: "
f"{prolific_participant_group_ids}"
)

# 4. Exclude the Worker from Prolific Participant Groups
for prolific_participant_group_id in prolific_participant_group_ids:
prolific_utils.exclude_worker_from_participant_group(
client,
self.worker_name,
prolific_participant_group_id,
)

logger.debug(f"{self.log_prefix}Worker {self.worker_name} excluded")

return True, ""

def is_blocked(self, requester: "Requester") -> bool:
"""Determine if a worker is blocked"""
task_run = self._get_first_task_run(requester)
Expand Down
12 changes: 12 additions & 0 deletions mephisto/data_model/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,18 @@ def new_from_provider_data(
agent = cls.new(db, worker, unit)
unit.worker_id = worker.db_id
agent._unit = unit

# In case provider API wasn't responsive, we ensure this submission
# doesn't exceed per-worker cap for this Task. Othewrwise don't process submission.
task_run: "TaskRun" = agent.get_task_run()
if not worker.can_send_more_submissions_for_task(task_run):
try:
worker.exclude_worker_from_task(task_run)
except Exception:
logger.exception(
f"Failed to exclude worker {worker.db_id} in TaskRun {task_run.db_id}."
)

return agent

def get_status(self) -> str:
Expand Down
10 changes: 10 additions & 0 deletions mephisto/data_model/task_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,16 @@ class TaskRunArgs:
},
)

max_submissions_per_worker: Optional[int] = field(
default=None,
metadata={
"help": (
"Maximum submissions that a worker can submit on across all "
"tasks that share this task_name. (0 is infinite)"
)
},
)

@classmethod
def get_mock_params(cls) -> str:
"""Returns a param string with default / mock arguments to use for testing"""
Expand Down
31 changes: 31 additions & 0 deletions mephisto/data_model/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
MephistoDataModelComponentMixin,
)
from typing import Any, List, Optional, Mapping, Tuple, Dict, Type, Tuple, TYPE_CHECKING

from mephisto.data_model.constants.assignment_state import AssignmentState
from mephisto.utils.logger_core import get_logger

logger = get_logger(name=__name__)
Expand Down Expand Up @@ -259,6 +261,16 @@ def unblock_worker(self, reason: str, requester: "Requester") -> bool:
"""unblock a blocked worker for the specified reason"""
raise NotImplementedError()

def exclude_worker_from_task(
self, task_run: Optional["TaskRun"] = None,
) -> Tuple[bool, str]:
"""
Prevent this worker from further participation in current Task.
(Note that scope of exclusion is only within the current Task,
whereas block lists or altering worker qualifications would affect future Tasks.)
"""
pass

def is_blocked(self, requester: "Requester") -> bool:
"""Determine if a worker is blocked"""
raise NotImplementedError()
Expand All @@ -267,6 +279,25 @@ def is_eligible(self, task_run: "TaskRun") -> bool:
"""Determine if this worker is eligible for the given task run"""
raise NotImplementedError()

def can_send_more_submissions_for_task(self, task_run: "TaskRun") -> bool:
"""Check whether a worker is allowed to send any more submissions within current Task"""
max_submissions_per_worker = task_run.args.max_submissions_per_worker

# By default, worker can send any amount of submissions
if max_submissions_per_worker is None:
return True

# Find all completed units byt this worker for current task
task_units = self.db.find_units(task_id=task_run.task_id, worker_id=self.db_id)
completed_task_units = [
u for u in task_units if u.get_status() in AssignmentState.completed()
]

if len(completed_task_units) >= max_submissions_per_worker:
return False

return True

def register(self, args: Optional[Dict[str, str]] = None) -> None:
"""Register this worker with the crowdprovider, if necessary"""
pass
Expand Down

0 comments on commit 3345689

Please sign in to comment.