METADATA_FILE = "agent_meta.json"


@dataclass
class _AgentStateMetadata:
    """
    Class to track the first-class feature fields of info about an AgentState.

    AgentState subclasses may choose to track additional metadata, but should
    put these as attributes of the agent state subclass directly.
    """

    # Values are whatever time.time() returned when the event was recorded;
    # None means the event has not been recorded yet.
    task_start: Optional[float] = None
    task_end: Optional[float] = None
    # TODO other metadata fields can be initialized


# The definitions below are methods of AgentState.


def __init__(self, agent: "Agent"):
    """
    Create an AgentState to track the state of an agent's work on a Unit.

    Keeps a weak proxy to the agent and immediately calls `load_data`.
    Implementations should initialize any required files for saving and
    loading state data in their `_load_data` methods. If said file already
    exists based on the given agent, they should load that data instead.
    """
    self.agent = weakref.proxy(agent)
    self.load_data()


def _get_metadata_path(self) -> str:
    """Return the path we expect to store metadata in."""
    data_dir = self.agent.get_data_dir()
    return os.path.join(data_dir, METADATA_FILE)


def load_metadata(self) -> None:
    """
    Read in the saved metadata for this agent state.

    Falls back to an empty `_AgentStateMetadata` when nothing has been
    stored under the metadata path yet.
    """
    md_path = self._get_metadata_path()
    if self.agent.db.key_exists(md_path):
        metadata_dict = self.agent.db.read_dict(md_path)
        self.metadata = _AgentStateMetadata(**metadata_dict)
    else:
        self.metadata = _AgentStateMetadata()


def save_metadata(self) -> None:
    """Write out the metadata for this agent state."""
    # Hand the storage layer a snapshot rather than the dataclass's live
    # attribute dict, so it cannot alias (or mutate) the in-memory metadata.
    metadata_dict = dict(vars(self.metadata))
    md_path = self._get_metadata_path()
    self.agent.db.write_dict(md_path, metadata_dict)


@abstractmethod
def _set_init_state(self, data: Any) -> None:
    """Set the initial state for this agent"""
    raise NotImplementedError()


def set_init_state(self, data: Any) -> bool:
    """
    Set the initial state for this agent, if it's not already set.

    Records the task start time, delegates to `_set_init_state`, and saves.

    Returns:
        True if the state was set by this call, False if an initial state
        was already present (in which case nothing is changed).
    """
    if self.get_init_state() is not None:
        return False
    self.metadata.task_start = time.time()
    self._set_init_state(data)
    self.save_data()
    return True


@abstractmethod
def _load_data(self) -> None:
    """
    Load stored data from a file to this object
    """
    raise NotImplementedError()


def load_data(self) -> None:
    """
    Load stored data from a file to this object, including metadata.

    Metadata is loaded first so that `self.metadata` already exists when
    `_load_data` runs: subclass loaders copy legacy timing fields into it,
    and loading metadata afterwards would replace the object and discard
    those values (or fail outright on a freshly constructed state).
    """
    self.load_metadata()
    self._load_data()
# The definitions below are methods of AgentState.


@abstractmethod
def get_data(self) -> Dict[str, Any]:
    """
    Return the currently stored data for this task.

    Must be provided by each AgentState implementation.
    """
    raise NotImplementedError()


@abstractmethod
def _save_data(self) -> None:
    """
    Persist the implementation-specific data to its expected location.

    Must be provided by each AgentState implementation; callers should go
    through `save_data` so the metadata is written as well.
    """
    raise NotImplementedError()


def save_data(self) -> None:
    """
    Persist this AgentState: implementation data first, then metadata.
    """
    self._save_data()
    self.save_metadata()


@abstractmethod
def _update_submit(self, submit_data: Dict[str, Any]) -> None:
    """
    Store the final submission data on this AgentState.

    Must be provided by each AgentState implementation; callers should go
    through `update_submit`, which also records the end time and saves.
    """
    raise NotImplementedError()


def update_submit(self, submit_data: Dict[str, Any]) -> None:
    """
    Record the final submission for this AgentState.

    Stamps the task end time in the metadata, lets the implementation
    store `submit_data`, then saves data and metadata.
    """
    finished_at = time.time()
    self.metadata.task_end = finished_at
    self._update_submit(submit_data)
    self.save_data()


def get_task_start(self) -> Optional[float]:
    """
    Return the recorded start time for this task, or None if not recorded.
    """
    metadata = self.metadata
    return metadata.task_start


def get_task_end(self) -> Optional[float]:
    """
    Return the recorded end time for this task, or None if not recorded.
    """
    metadata = self.metadata
    return metadata.task_end
# The definitions below are methods of the static-task AgentState.


def _set_init_state(self, data: Any):
    """Store `data` as the "inputs" entry of this agent's state."""
    self.state["inputs"] = data


def _load_data(self) -> None:
    """
    Populate `self.state` from the stored data file, if one exists.

    When nothing is stored under the data path, the state is reset to the
    empty state instead.
    """
    data_path = os.path.join(self.agent.get_data_dir(), DATA_FILE)
    if not self.agent.db.key_exists(data_path):
        self.state = self._get_empty_state()
        return

    self.state = self.agent.db.read_dict(data_path)
    # Old compatibility with saved times: earlier records kept the start and
    # end times inside the state dict itself; copy them into the metadata.
    if "times" in self.state:
        assert isinstance(self.state["times"], dict)
        legacy_times = self.state["times"]
        self.metadata.task_start = legacy_times["task_start"]
        self.metadata.task_end = legacy_times["task_end"]


def get_data(self) -> Dict[str, Any]:
    """Return a shallow copy of this agent's state dict."""
    snapshot = self.state.copy()
    return snapshot


def _save_data(self) -> None:
    """Write the state dict to the data file inside the agent's data dir."""
    out_filename = os.path.join(self.agent.get_data_dir(), DATA_FILE)
    self.agent.db.write_dict(out_filename, self.state)
    logger.info(f"SAVED_DATA_TO_DISC at {out_filename}")
# The definitions below are methods of MockAgentState.


def __init__(self, agent: "Agent"):
    """
    Mock agent states keep everything in local memory.

    The base initializer runs first; the in-memory fields are then reset.
    """
    super().__init__(agent)
    # Live/submitted data and the initial state both live only on the object.
    self.state: Dict[str, Any] = {}
    self.init_state: Any = None


def _set_init_state(self, data: Any):
    """Remember `data` as the initial state for this agent."""
    self.init_state = data


def get_init_state(self) -> Optional[Dict[str, Any]]:
    """Return the stored initial state (None until one has been set)."""
    current = self.init_state
    return current


def _load_data(self) -> None:
    """Nothing to load: mock agent states have no data stored."""
    return None


def get_data(self) -> Dict[str, Any]:
    """Return this agent's state dict (the object itself, not a copy)."""
    current = self.state
    return current


def _save_data(self) -> None:
    """Nothing to save: mock agents don't save data (yet)."""
    return None


def update_data(self, live_update: Dict[str, Any]) -> None:
    """Replace the mock state with the incoming live update."""
    self.state = live_update


def _update_submit(self, submitted_data: Dict[str, Any]) -> None:
    """Replace the mock state with the submitted data."""
    self.state = submitted_data
# The definitions below are methods of ParlAIChatAgentState.


def _set_init_state(self, data: Any):
    """Remember `data` as the initial state for this agent."""
    self.init_data: Optional[Any] = data


def _get_expected_data_file(self) -> str:
    """Return the path of `state.json` inside this agent's data dir."""
    return os.path.join(self.agent.get_data_dir(), "state.json")


def _load_data(self) -> None:
    """
    Populate messages, init data and final submission from stored state.

    When nothing is stored under the expected path, the fields are set to
    their empty defaults instead.
    """
    agent_file = self._get_expected_data_file()
    if self.agent.db.key_exists(agent_file):
        saved = self.agent.db.read_dict(agent_file)
        self.messages = saved["outputs"]["messages"]
        self.init_data = saved["inputs"]
        # Read with .get, so records lacking this key load as None.
        self.final_submission = saved["outputs"].get("final_submission")
        return

    self.messages: List[Dict[str, Any]] = []
    self.final_submission: Optional[Dict[str, Any]] = None
    self.init_data = None


def _save_data(self) -> None:
    """Write the dict returned by `get_data` to the expected state path."""
    target = self._get_expected_data_file()
    payload = self.get_data()
    self.agent.db.write_dict(target, payload)


def _update_submit(self, submitted_data: Dict[str, Any]) -> None:
    """Keep the final submission on this state."""
    self.final_submission = submitted_data
# The definitions below are methods of RemoteProcedureAgentState.


def _set_init_state(self, data: Any):
    """Remember `data` as the initial state for this agent."""
    self.init_data: Optional[Dict[str, Any]] = data


def _load_data(self) -> None:
    """
    Reset the in-memory fields, then overwrite them from stored state.

    If nothing is stored under the expected path, the empty defaults
    remain in place.
    """
    self.requests: Dict[str, RemoteRequest] = {}
    self.init_data = None
    self.final_submission: Optional[Dict[str, Any]] = None
    agent_file = self._get_expected_data_file()
    if not self.agent.db.key_exists(agent_file):
        return

    state = self.agent.db.read_dict(agent_file)
    restored = {}
    for raw_request in state["requests"]:
        restored[raw_request["uuid"]] = RemoteRequest(**raw_request)
    self.requests = restored
    self.init_data = state["init_data"]
    self.final_submission = state["final_submission"]
    # Backwards compatibility for times: copy the times kept in the state
    # file into the metadata object.
    if "start_time" in state:
        self.metadata.task_start = state["start_time"]
        self.metadata.task_end = state["end_time"]


def get_data(self) -> Dict[str, Any]:
    """
    Return a dict with the submission, init data, requests and times.

    Requests are serialized with their `to_dict` method; the times are
    read from the metadata object.
    """
    data: Dict[str, Any] = {}
    data["final_submission"] = self.final_submission
    data["init_data"] = self.init_data
    data["requests"] = [request.to_dict() for request in self.requests.values()]
    data["start_time"] = self.metadata.task_start
    data["end_time"] = self.metadata.task_end
    return data


def _save_data(self) -> None:
    """Write the dict returned by `get_data` to the expected state path."""
    target = self._get_expected_data_file()
    payload = self.get_data()
    self.agent.db.write_dict(target, payload)


def _update_submit(self, submitted_data: Dict[str, Any]) -> None:
    """Keep the final submission on this state."""
    self.final_submission = submitted_data
# File/blob manipulation methods
#
# The definitions below are abstract methods of the database base class;
# every concrete database must provide them.


@abstractmethod
def write_dict(self, path_key: str, target_dict: Dict[str, Any]):
    """Store `target_dict` under `path_key`."""
    raise NotImplementedError()


@abstractmethod
def read_dict(self, path_key: str) -> Dict[str, Any]:
    """Return the dict stored under `path_key`."""
    raise NotImplementedError()


@abstractmethod
def write_text(self, path_key: str, data_string: str):
    """Store the text `data_string` under `path_key`."""
    raise NotImplementedError()


@abstractmethod
def read_text(self, path_key: str) -> str:
    """Return the text stored under `path_key`."""
    raise NotImplementedError()


@abstractmethod
def key_exists(self, path_key: str) -> bool:
    """Return whether `path_key` refers to a known file."""
    raise NotImplementedError()
# File/blob manipulation methods
#
# The definitions below are methods of the local (filesystem-backed) database.


def _assert_path_in_domain(self, path_key: str) -> None:
    """
    Ensure `path_key` lies inside `self.db_root`.

    Both paths are normalized lexically before comparison, so ".." segments
    and sibling directories that merely share a string prefix with the root
    are rejected.

    Raises:
        AssertionError: if the key is outside the root. Raised explicitly
            (not via an `assert` statement) so the check still runs under
            `python -O`.
    """
    root = os.path.normcase(os.path.abspath(self.db_root))
    target = os.path.normcase(os.path.abspath(path_key))
    try:
        in_domain = os.path.commonpath([root, target]) == root
    except ValueError:
        # commonpath refuses paths that cannot share a prefix at all
        # (e.g. different drives on Windows); those are outside the root.
        in_domain = False
    if not in_domain:
        raise AssertionError(
            f"Accessing invalid key {path_key} for root {self.db_root}"
        )


def write_dict(self, path_key: str, target_dict: Dict[str, Any]):
    """Serialize `target_dict` as JSON into the file at `path_key`."""
    self._assert_path_in_domain(path_key)
    parent_dir = os.path.dirname(path_key)
    if parent_dir:
        # os.makedirs("") raises, so only create a parent when there is one.
        os.makedirs(parent_dir, exist_ok=True)
    with open(path_key, "w") as data_file:
        json.dump(target_dict, data_file)


def read_dict(self, path_key: str) -> Dict[str, Any]:
    """Return the dict parsed from the JSON file at `path_key`."""
    self._assert_path_in_domain(path_key)
    with open(path_key, "r") as data_file:
        return json.load(data_file)


def write_text(self, path_key: str, data_string: str):
    """Write `data_string` to the file at `path_key`."""
    self._assert_path_in_domain(path_key)
    parent_dir = os.path.dirname(path_key)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(path_key, "w") as data_file:
        data_file.write(data_string)


def read_text(self, path_key: str) -> str:
    """Return the text content of the file at `path_key`."""
    self._assert_path_in_domain(path_key)
    with open(path_key, "r") as data_file:
        return data_file.read()


def key_exists(self, path_key: str) -> bool:
    """Return whether `path_key` refers to an existing path on disk."""
    self._assert_path_in_domain(path_key)
    return os.path.exists(path_key)
from mephisto.abstractions._subcomponents.channel import Channel, STATUS_CHECK_TIME from typing import Dict, Tuple, Union, Optional, List, Any, TYPE_CHECKING @@ -289,6 +289,7 @@ def _on_submit_onboarding(self, packet: Packet, channel_id: str) -> None: assert ( "onboarding_data" in packet.data ), f"Onboarding packet {packet} submitted without data" + agent: Optional["_AgentBase"] live_run = self.get_live_run() onboarding_id = packet.subject_id if onboarding_id not in live_run.worker_pool.onboarding_agents: