Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent oversampling of study submissions by any single worker #1080

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mephisto/abstractions/providers/mturk/mturk_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,8 @@ def create_hit_type(
has_locale_qual = True
locale_requirements += existing_qualifications

if not has_locale_qual and not client_is_sandbox(client):
is_sandbox = client_is_sandbox(client)
if not has_locale_qual and not is_sandbox:
allowed_locales = get_config_arg("mturk", "allowed_locales")
if allowed_locales is None:
allowed_locales = [
Expand All @@ -458,6 +459,9 @@ def create_hit_type(
}
)

if is_sandbox:
hit_reward = 0

# Create the HIT type
response = client.create_hit_type(
AutoApprovalDelayInSeconds=auto_approve_delay,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ def _base_request(
else:
result = response.json()

logger.debug(f"{log_prefix} Response: {result}")

return result

except ProlificException:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def remove_participants_from_group(
https://docs.prolific.co/docs/api-docs/public/#tag/
Participant-Groups/paths/~1api~1v1~1participant-groups~1%7Bid%7D~1participants~1/delete
"""
from mephisto.utils.logger_core import get_logger

logger = get_logger(name=__name__)
endpoint = cls.list_participants_for_group_api_endpoint.format(id=id)
params = dict(participant_ids=participant_ids)
response_json = cls.delete(endpoint, params=params)
Expand Down
22 changes: 20 additions & 2 deletions mephisto/abstractions/providers/prolific/prolific_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,29 @@ def new_from_provider_data(

assert isinstance(unit, ProlificUnit), "Can only register Prolific agents to Prolific units"

agent = cls.new(db, worker, unit)
unit.worker_id = worker.db_id
agent._unit = unit
task_run: "TaskRun" = agent.get_task_run()

prolific_study_id = provider_data["prolific_study_id"]
prolific_submission_id = provider_data["assignment_id"]
unit.register_from_provider_data(prolific_study_id, prolific_submission_id)

logger.debug("Prolific Submission has been registered successfully")

return super().new_from_provider_data(db, worker, unit, provider_data)
# Check whether we need to prevent this worker from future submissions in this Task
if not worker.can_send_more_submissions_for_task(task_run):
# Excluding worker from Participant Group (instead of adding to Block List)
# only because Prolific cannot update Block List for an in-progress Study
try:
worker.exclude_worker_from_task(task_run)
except Exception:
logger.exception(
f"Failed to exclude worker {worker.db_id} in TaskRun {task_run.db_id}."
)

return agent

def approve_work(
self,
Expand Down Expand Up @@ -239,7 +255,6 @@ def get_status(self) -> str:
if prolific_submission_id:
prolific_submission = prolific_utils.get_submission(client, prolific_submission_id)
else:
# TODO: Not sure about this
self.update_status(AgentState.STATUS_EXPIRED)
return self.db_status

Expand All @@ -249,6 +264,9 @@ def get_status(self) -> str:

if prolific_submission.status == SubmissionStatus.RESERVED:
provider_status = local_status
elif prolific_submission.status == SubmissionStatus.ACTIVE:
# We don't need to map this status in our DB
pass
else:
provider_status = SUBMISSION_STATUS_TO_AGENT_STATE_MAP.get(
prolific_submission.status,
Expand Down
10 changes: 6 additions & 4 deletions mephisto/abstractions/providers/prolific/prolific_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def get_blocked_workers(self) -> List[dict]:
results = c.fetchall()
return results

def get_bloked_participant_ids(self) -> List[str]:
def get_blocked_participant_ids(self) -> List[str]:
return [w["worker_id"] for w in self.get_blocked_workers()]

def ensure_unit_exists(self, unit_id: str) -> None:
Expand Down Expand Up @@ -629,7 +629,7 @@ def find_qualifications_by_ids(
task_run_ids: Optional[List[str]] = None,
) -> List[dict]:
"""Find qualifications by Mephisto ids of qualifications and task runs"""
if not qualification_ids:
if not (qualification_ids or task_run_ids):
return []

with self.table_access_condition, self._get_connection() as conn:
Expand All @@ -645,12 +645,14 @@ def find_qualifications_by_ids(
task_run_ids_block = ""
if task_run_ids:
task_run_ids_str = ",".join([f'"{tid}"' for tid in task_run_ids])
task_run_ids_block = f"AND task_run_id IN ({task_run_ids_str})"
task_run_ids_block = f"task_run_id IN ({task_run_ids_str})"

where_block = " AND ".join(filter(bool, [qualification_ids_block, task_run_ids_block]))

c.execute(
f"""
SELECT * FROM qualifications
WHERE {qualification_ids_block} {task_run_ids_block};
WHERE {where_block};
"""
)
results = c.fetchall()
Expand Down
42 changes: 31 additions & 11 deletions mephisto/abstractions/providers/prolific/prolific_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from mephisto.abstractions.providers.prolific.prolific_unit import ProlificUnit
from mephisto.abstractions.providers.prolific.prolific_worker import ProlificWorker
from mephisto.abstractions.providers.prolific.provider_type import PROVIDER_TYPE
from mephisto.data_model.worker import Worker
from mephisto.operations.registry import register_mephisto_abstraction
from mephisto.utils.logger_core import get_logger
from mephisto.utils.qualifications import QualificationType
Expand All @@ -44,14 +45,13 @@
from .api.exceptions import ProlificException

if TYPE_CHECKING:
from mephisto.data_model.task import Task
from mephisto.data_model.task_run import TaskRun
from mephisto.data_model.unit import Unit
from mephisto.data_model.worker import Worker
from mephisto.data_model.requester import Requester
from mephisto.data_model.agent import Agent
from mephisto.abstractions.blueprint import SharedTaskState


DEFAULT_FRAME_HEIGHT = 0
DEFAULT_PROLIFIC_GROUP_NAME_ALLOW_LIST = "Allow list"
DEFAULT_PROLIFIC_GROUP_NAME_BLOCK_LIST = "Block list"
Expand Down Expand Up @@ -173,12 +173,13 @@ def _get_client(self, requester_name: str) -> ProlificClient:
def _get_qualified_workers(
self,
qualifications: List[QualificationType],
bloked_participant_ids: List[str],
blocked_participant_ids: List[str],
task_run: "TaskRun",
) -> List["Worker"]:
qualified_workers = []
workers: List[Worker] = self.db.find_workers(provider_type="prolific")
# `worker_name` is Prolific Participant ID in provider-specific datastore
available_workers = [w for w in workers if w.worker_name not in bloked_participant_ids]
available_workers = [w for w in workers if w.worker_name not in blocked_participant_ids]

for worker in available_workers:
if worker_is_qualified(worker, qualifications):
Expand Down Expand Up @@ -213,6 +214,20 @@ def _create_participant_group_with_qualified_workers(
)
return prolific_participant_group

def _get_excluded_participant_ids(self, task_run: "TaskRun") -> List[str]:
    """
    Find participant_ids that exceeded `maximum_units_per_worker` cap within this Task.

    Looks at every unit of the Task that `task_run` belongs to, and for each worker
    linked to at least one of those units asks
    `worker.can_send_more_submissions_for_task(task_run)`.

    :param task_run: TaskRun whose parent Task is inspected.
    :return: Unique `worker_name` values (in first-seen order) of workers
        that cannot send more submissions for this Task.
    """
    task: "Task" = task_run.get_task()
    task_units: List["Unit"] = self.db.find_units(task_id=task.db_id)

    # The same worker may be linked to several units of the Task, so de-duplicate
    # worker IDs up front: each worker is then loaded and checked only once.
    # `dict.fromkeys` keeps first-seen order, unlike `set`.
    worker_ids = dict.fromkeys(unit.worker_id for unit in task_units if unit.worker_id)

    excluded_participant_ids: List[str] = []
    for worker_id in worker_ids:
        worker: "Worker" = Worker.get(self.db, worker_id)
        if not worker.can_send_more_submissions_for_task(task_run):
            excluded_participant_ids.append(worker.worker_name)

    # Keep the result unique by name, as before, but with a deterministic order
    return list(dict.fromkeys(excluded_participant_ids))

def setup_resources_for_task_run(
self,
task_run: "TaskRun",
Expand Down Expand Up @@ -261,11 +276,12 @@ def setup_resources_for_task_run(
title=args.provider.prolific_project_name,
)

blocked_participant_ids = self.datastore.get_bloked_participant_ids()

blocked_participant_ids: List[str] = self.datastore.get_blocked_participant_ids()
excluded_participant_ids: List[str] = self._get_excluded_participant_ids(task_run)
# If no Mephisto qualifications found,
# we need to block Mephisto workers on Prolific as well
if blocked_participant_ids:
participant_ids_to_add_to_block_list = blocked_participant_ids + excluded_participant_ids
if participant_ids_to_add_to_block_list:
new_prolific_specific_qualifications = []
# Add an empty Blacklist in case there is none in state or config
blacklist_qualification = DictConfig(
Expand All @@ -285,27 +301,31 @@ def setup_resources_for_task_run(
whitelist_qualification = prolific_specific_qualification
prev_value = whitelist_qualification["white_list"]
whitelist_qualification["white_list"] = [
p for p in prev_value if p not in blocked_participant_ids
p for p in prev_value if p not in participant_ids_to_add_to_block_list
]
new_prolific_specific_qualifications.append(whitelist_qualification)
elif name == ParticipantGroupEligibilityRequirement.name:
# Remove blocked Participant IDs from Participant Group Eligibility Requirement
client.ParticipantGroups.remove_participants_from_group(
id=prolific_specific_qualification["id"],
participant_ids=blocked_participant_ids,
participant_ids=participant_ids_to_add_to_block_list,
)
else:
new_prolific_specific_qualifications.append(prolific_specific_qualification)

# Set Blacklist Eligibility Requirement
blacklist_qualification["black_list"] = list(
set(blacklist_qualification["black_list"] + blocked_participant_ids)
set(blacklist_qualification["black_list"] + participant_ids_to_add_to_block_list)
)
new_prolific_specific_qualifications.append(blacklist_qualification)
prolific_specific_qualifications = new_prolific_specific_qualifications

if qualifications:
qualified_workers = self._get_qualified_workers(qualifications, blocked_participant_ids)
qualified_workers = self._get_qualified_workers(
qualifications,
participant_ids_to_add_to_block_list,
task_run,
)

if qualified_workers:
prolific_workers_ids = [w.worker_name for w in qualified_workers]
Expand Down
22 changes: 21 additions & 1 deletion mephisto/abstractions/providers/prolific/prolific_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,15 @@ def compose_completion_codes(code_suffix: str) -> List[dict]:
),
],
),
dict(
code=f"{constants.StudyCodeType.OTHER}_{code_suffix}",
code_type=constants.StudyCodeType.OTHER,
actions=[
dict(
action=constants.StudyAction.MANUALLY_REVIEW,
),
],
),
]

# Task info
Expand Down Expand Up @@ -578,7 +587,10 @@ def remove_worker_qualification(
*args,
**kwargs,
) -> None:
"""Remove a qualification for the given worker (remove a worker from a Participant Group)"""
"""
Remove a qualification for the given worker (remove a worker from a Participant Group).
NOTE: If a participant is not a member of the group, they will be ignored (from API Docs)
"""
try:
client.ParticipantGroups.remove_participants_from_group(
id=qualification_id,
Expand All @@ -591,6 +603,14 @@ def remove_worker_qualification(
raise


def exclude_worker_from_participant_group(
    client: ProlificClient,
    worker_id: str,
    participant_group_id: str,
):
    """
    Exclude a worker from a Prolific Participant Group.

    Thin alias over `remove_worker_qualification`: the arguments are forwarded
    positionally, with `participant_group_id` passed as the third argument.
    """
    remove_worker_qualification(client, worker_id, participant_group_id)


def pay_bonus(
client: ProlificClient,
task_run_config: "DictConfig",
Expand Down
43 changes: 41 additions & 2 deletions mephisto/abstractions/providers/prolific/prolific_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
from typing import Tuple
from typing import TYPE_CHECKING

from omegaconf import DictConfig

from mephisto.abstractions.providers.prolific import prolific_utils
from mephisto.abstractions.providers.prolific.api.client import ProlificClient
from mephisto.abstractions.providers.prolific.provider_type import PROVIDER_TYPE
Expand All @@ -28,6 +26,7 @@
from mephisto.abstractions.providers.prolific.prolific_requester import ProlificRequester
from mephisto.abstractions.providers.prolific.prolific_unit import ProlificUnit
from mephisto.data_model.requester import Requester
from mephisto.data_model.task import Task
from mephisto.data_model.task_run import TaskRun
from mephisto.data_model.unit import Unit

Expand Down Expand Up @@ -181,6 +180,46 @@ def unblock_worker(self, reason: str, requester: "Requester") -> Tuple[bool, str

return True, ""

def exclude_worker_from_task(
    self,
    task_run: Optional["TaskRun"] = None,
) -> Tuple[bool, str]:
    """
    Exclude this worker from current Task.

    Removes the worker from every Prolific Participant Group recorded in the
    provider datastore for any TaskRun of the Task that `task_run` belongs to.

    :param task_run: TaskRun used to resolve the requester (for the client) and
        the parent Task. The signature allows omitting it, but nothing can be
        resolved without it.
    :return: `(True, "")` on success, or `(False, reason)` if `task_run` was not given.
        Errors coming from `prolific_utils` are not caught here.
    """
    if task_run is None:
        # Previously this fell through to `task_run.get_requester()` and raised
        # an opaque AttributeError; report the problem via the return value instead
        return False, "Cannot exclude worker from Task: `task_run` was not provided"

    logger.debug(f"{self.log_prefix}Excluding worker {self.worker_name} from Prolific")

    # 1. Get Client
    requester: "ProlificRequester" = task_run.get_requester()
    client = self._get_client(requester.requester_name)

    # 2. Find TaskRun IDs that are related to current Task
    task: "Task" = task_run.get_task()
    all_task_run_ids_for_task: List[str] = [t.db_id for t in task.get_runs()]

    # 3. Select all Participant Group IDs that are related to the Task
    datastore_qualifications = self.datastore.find_qualifications_by_ids(
        task_run_ids=all_task_run_ids_for_task,
    )
    # De-duplicate (keeping order) so the same group is not sent to the API twice
    prolific_participant_group_ids = list(
        dict.fromkeys(q["prolific_participant_group_id"] for q in datastore_qualifications)
    )

    logger.debug(
        f"{self.log_prefix}Found {len(prolific_participant_group_ids)} Participant Groups: "
        f"{prolific_participant_group_ids}"
    )

    # 4. Exclude the Worker from Prolific Participant Groups
    for prolific_participant_group_id in prolific_participant_group_ids:
        prolific_utils.exclude_worker_from_participant_group(
            client,
            self.worker_name,
            prolific_participant_group_id,
        )

    logger.debug(f"{self.log_prefix}Worker {self.worker_name} excluded")

    return True, ""

def is_blocked(self, requester: "Requester") -> bool:
"""Determine if a worker is blocked"""
task_run = self._get_first_task_run(requester)
Expand Down
11 changes: 11 additions & 0 deletions mephisto/data_model/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,17 @@ def new_from_provider_data(
agent = cls.new(db, worker, unit)
unit.worker_id = worker.db_id
agent._unit = unit

# Prevent sending more units to worker if worker exceeded submission cap within this Task
task_run: "TaskRun" = agent.get_task_run()
if not worker.can_send_more_submissions_for_task(task_run):
try:
worker.exclude_worker_from_task(task_run)
except Exception:
logger.exception(
f"Failed to exclude worker {worker.db_id} in TaskRun {task_run.db_id}."
)

return agent

def get_status(self) -> str:
Expand Down
12 changes: 8 additions & 4 deletions mephisto/data_model/task_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,14 @@ def get_valid_units_for_worker(self, worker: "Worker") -> List["Unit"]:

# Cannot pair with self
units: List["Unit"] = []
for unit_set in unit_assigns.values():
is_self_set = map(lambda u: u.worker_id == worker.db_id, unit_set)
if not any(is_self_set):
units += unit_set
for unit_list in unit_assigns.values():
self_linked_units = [
u
for u in unit_list
if u.worker_id == worker.db_id and u.db_status == AssignmentState.LAUNCHED
]
if not self_linked_units:
units += unit_list

# Valid units must be launched and must not be special units (negative indices)
# Can use db_status directly rather than polling in the critical path, as in
Expand Down
Loading
Loading