Skip to content

Commit

Permalink
Merge pull request #814 from facebookresearch/agent-state-consolidate
Browse files Browse the repository at this point in the history
Consolidating AgentState metadata
  • Loading branch information
JackUrb committed Jul 7, 2022
2 parents c776529 + 8b22ec1 commit f2a4d62
Show file tree
Hide file tree
Showing 9 changed files with 208 additions and 166 deletions.
97 changes: 86 additions & 11 deletions mephisto/abstractions/_subcomponents/agent_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
Union,
TYPE_CHECKING,
)
from dataclasses import dataclass
import time
import weakref
import os.path

if TYPE_CHECKING:
from mephisto.data_model.agent import Agent, OnboardingAgent
Expand All @@ -24,6 +28,22 @@

logger = get_logger(name=__name__)

METADATA_FILE = "agent_meta.json"


@dataclass
class _AgentStateMetadata:
"""
Class to track the first-class feature fields of info about an AgentState.
AgentState subclasses may choose to track additional metadata, but should
put these as attributes of the agent state subclass directly.
"""

task_start: Optional[float] = None
task_end: Optional[float] = None
# TODO other metadata fields can be initialized


# TODO(#567) File manipulations should ultimately be handled by the MephistoDB, rather than
# direct reading and writing within. This allows for a better abstraction between
Expand Down Expand Up @@ -108,24 +128,57 @@ def valid() -> List[str]:

# Implementations of an AgentState must implement the following:

@abstractmethod
def __init__(self, agent: "Agent"):
"""
Create an AgentState to track the state of an agent's work on a Unit
Implementations should initialize any required files for saving and
loading state data somewhere.
loading state data somewhere in their _load_data methods
If said file already exists based on the given agent, load that data
instead.
"""
raise NotImplementedError()
self.agent = weakref.proxy(agent)
self.load_data()

def _get_metadata_path(self) -> str:
"""Return the path we expect to store metadata in"""
data_dir = self.agent.get_data_dir()
return os.path.join(data_dir, METADATA_FILE)

def load_metadata(self) -> None:
"""Write out the metadata for this agent state to file"""
md_path = self._get_metadata_path()
if self.agent.db.key_exists(md_path):
metadata_dict = self.agent.db.read_dict(md_path)
self.metadata = _AgentStateMetadata(**metadata_dict)
else:
self.metadata = _AgentStateMetadata()

def save_metadata(self) -> None:
"""Read in the saved metadata for this agent state from file"""
metadata_dict = self.metadata.__dict__
md_path = self._get_metadata_path()
self.agent.db.write_dict(md_path, metadata_dict)

@abstractmethod
def set_init_state(self, data: Any) -> bool:
def _set_init_state(self, data: Any) -> None:
"""Set the initial state for this agent"""
raise NotImplementedError()

def set_init_state(self, data: Any) -> bool:
"""
Set the initial state for this agent, if it's not already set
Update the start time and return true if set, otherwise return false
"""
if self.get_init_state() is not None:
return False
self.metadata.task_start = time.time()
self._set_init_state(data)
self.save_data()
return True

@abstractmethod
def get_init_state(self) -> Optional[Any]:
"""
Expand All @@ -135,17 +188,23 @@ def get_init_state(self) -> Optional[Any]:
raise NotImplementedError()

@abstractmethod
def load_data(self) -> None:
def _load_data(self) -> None:
"""
Load stored data from a file to this object
"""
raise NotImplementedError()

def load_data(self) -> None:
"""
Load stored data from a file to this object, including metadata
"""
self._load_data()
self.load_metadata()

@abstractmethod
def get_data(self) -> Dict[str, Any]:
"""
Return the currently stored data for this task in the format
expected by any frontend displays
Return the currently stored data for this task
"""
raise NotImplementedError()

Expand All @@ -161,12 +220,19 @@ def get_parsed_data(self) -> Any:
return self.get_data()

@abstractmethod
def save_data(self) -> None:
def _save_data(self) -> None:
"""
Save the relevant data from this Unit to a file in the expected location
"""
raise NotImplementedError()

def save_data(self) -> None:
"""
Save the relevant data from this AgentState, including metadata
"""
self._save_data()
self.save_metadata()

@abstractmethod
def update_data(self, live_update: Dict[str, Any]) -> None:
"""
Expand All @@ -176,20 +242,29 @@ def update_data(self, live_update: Dict[str, Any]) -> None:
raise NotImplementedError()

@abstractmethod
def update_submit(self, submit_data: Dict[str, Any]) -> None:
def _update_submit(self, submit_data: Dict[str, Any]) -> None:
"""
Update this AgentState with the final submission data.
"""
raise NotImplementedError()

def update_submit(self, submit_data: Dict[str, Any]) -> None:
"""
Update this AgentState with the final submission data, marking
completion of the task in the metadata
"""
self.metadata.task_end = time.time()
self._update_submit(submit_data)
self.save_data()

def get_task_start(self) -> Optional[float]:
"""
Return the start time for this task, if it is available
"""
return 0.0
return self.metadata.task_start

def get_task_end(self) -> Optional[float]:
"""
Return the end time for this task, if it is available
"""
return 0.0
return self.metadata.task_end
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@

from typing import List, Dict, Optional, Any, TYPE_CHECKING
from mephisto.abstractions.blueprint import AgentState
import os
import json
import time
import weakref
import os.path

if TYPE_CHECKING:
from mephisto.data_model.agent import Agent
Expand All @@ -31,30 +28,11 @@ def _get_empty_state(self) -> Dict[str, Optional[Dict[str, Any]]]:
return {
"inputs": None,
"outputs": None,
"times": {"task_start": 0, "task_end": 0},
}

def __init__(self, agent: "Agent"):
"""
Static agent states should store
input dict -> output dict pairs to disc
"""
self.agent = weakref.proxy(agent)
self.state: Dict[str, Optional[Dict[str, Any]]] = self._get_empty_state()
self.load_data()

def set_init_state(self, data: Any) -> bool:
def _set_init_state(self, data: Any):
"""Set the initial state for this agent"""
if self.get_init_state() is not None:
# Initial state is already set
return False
else:
self.state["inputs"] = data
times_dict = self.state["times"]
assert isinstance(times_dict, dict)
times_dict["task_start"] = time.time()
self.save_data()
return True
self.state["inputs"] = data

def get_init_state(self) -> Optional[Dict[str, Any]]:
"""
Expand All @@ -65,27 +43,29 @@ def get_init_state(self) -> Optional[Dict[str, Any]]:
return None
return self.state["inputs"].copy()

def load_data(self) -> None:
def _load_data(self) -> None:
"""Load data for this agent from disk"""
data_dir = self.agent.get_data_dir()
data_path = os.path.join(data_dir, DATA_FILE)
if os.path.exists(data_path):
with open(data_path, "r") as data_file:
self.state = json.load(data_file)
if self.agent.db.key_exists(data_path):
self.state = self.agent.db.read_dict(data_path)
# Old compatibility with saved times
if "times" in self.state:
assert isinstance(self.state["times"], dict)
self.metadata.task_start = self.state["times"]["task_start"]
self.metadata.task_end = self.state["times"]["task_end"]
else:
self.state = self._get_empty_state()

def get_data(self) -> Dict[str, Any]:
"""Return dict of this agent's state"""
return self.state.copy()

def save_data(self) -> None:
def _save_data(self) -> None:
"""Save static agent data to disk"""
data_dir = self.agent.get_data_dir()
os.makedirs(data_dir, exist_ok=True)
out_filename = os.path.join(data_dir, DATA_FILE)
with open(out_filename, "w+") as data_file:
json.dump(self.state, data_file)
self.agent.db.write_dict(out_filename, self.state)
logger.info(f"SAVED_DATA_TO_DISC at {out_filename}")

def update_data(self, live_update: Dict[str, Any]) -> None:
Expand All @@ -94,7 +74,7 @@ def update_data(self, live_update: Dict[str, Any]) -> None:
"""
raise Exception("Static tasks should only have final act, but got live update")

def update_submit(self, submission_data: Dict[str, Any]) -> None:
def _update_submit(self, submission_data: Dict[str, Any]) -> None:
"""Move the submitted output to the local dict"""
outputs: Dict[str, Any]
assert isinstance(submission_data, dict), (
Expand All @@ -105,23 +85,3 @@ def update_submit(self, submission_data: Dict[str, Any]) -> None:
if output_files is not None:
submission_data["files"] = [f["filename"] for f in submission_data["files"]]
self.state["outputs"] = submission_data
times_dict = self.state["times"]
assert isinstance(times_dict, dict)
times_dict["task_end"] = time.time()
self.save_data()

def get_task_start(self) -> Optional[float]:
"""
Extract out and return the start time recorded for this task.
"""
stored_times = self.state["times"]
assert stored_times is not None
return stored_times["task_start"]

def get_task_end(self) -> Optional[float]:
"""
Extract out and return the end time recorded for this task.
"""
stored_times = self.state["times"]
assert stored_times is not None
return stored_times["task_end"]
19 changes: 6 additions & 13 deletions mephisto/abstractions/blueprints/mock/mock_agent_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from mephisto.abstractions.blueprint import AgentState
import os
import json
import weakref

if TYPE_CHECKING:
from mephisto.data_model.agent import Agent
Expand All @@ -22,19 +21,13 @@ class MockAgentState(AgentState):

def __init__(self, agent: "Agent"):
"""Mock agent states keep everything in local memory"""
self.agent = weakref.proxy(agent)
super().__init__(agent)
self.state: Dict[str, Any] = {}
self.init_state: Any = None

def set_init_state(self, data: Any) -> bool:
def _set_init_state(self, data: Any):
"""Set the initial state for this agent"""
if self.init_state is not None:
# Initial state is already set
return False
else:
self.init_state = data
self.save_data()
return True
self.init_state = data

def get_init_state(self) -> Optional[Dict[str, Any]]:
"""
Expand All @@ -43,22 +36,22 @@ def get_init_state(self) -> Optional[Dict[str, Any]]:
"""
return self.init_state

def load_data(self) -> None:
def _load_data(self) -> None:
"""Mock agent states have no data stored"""
pass

def get_data(self) -> Dict[str, Any]:
"""Return dict of this agent's state"""
return self.state

def save_data(self) -> None:
def _save_data(self) -> None:
"""Mock agents don't save data (yet)"""
pass

def update_data(self, live_update: Dict[str, Any]) -> None:
"""Put new data into this mock state"""
self.state = live_update

def update_submit(self, submitted_data: Dict[str, Any]) -> None:
def _update_submit(self, submitted_data: Dict[str, Any]) -> None:
"""Move the submitted data into the live state"""
self.state = submitted_data
Loading

0 comments on commit f2a4d62

Please sign in to comment.