Skip to content

Commit

Permalink
Adding Module Dependencies feature (#1876)
Browse files Browse the repository at this point in the history
* ran precommit

* first draft of modeule deps

* updated changelog

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added support for directly adding modules

* minor edit

* debugging and trying

* now working as expected except for postprocess

* some significant size improvements

* fixed postprocess issue

* removed unnecesary moduledeps additions

* fixed one electron metadata issue

* removed some unneeded changes

* removed TG size printing

* added transport.py tests

* added tests for electron file changes

* fixed tests

* fixing more tests

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
kessler-frost and pre-commit-ci[bot] committed Dec 11, 2023
1 parent d94bdaf commit 6e331d4
Show file tree
Hide file tree
Showing 11 changed files with 284 additions and 18 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [UNRELEASED]

### Added

- Added feature to use custom python files as modules to be used in the electron function

### Changed

- SDK no longer uploads empty assets when submitting a dispatch.
Expand All @@ -18,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
important, whether the size is reported to be zero or positive does
have consequences.
- Pack deps, call_before, and call_after assets into one file.
- Changed handling of tuples and sets when building the transport graph - they will be converted to electron lists as well for now
- `qelectron_db`, `qelectron_data_exists`, `python_version`, and `covalent_version`
are now optional in the pydantic model definitions.

Expand Down
1 change: 1 addition & 0 deletions covalent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from ._workflow import ( # nopycln: import
DepsBash,
DepsCall,
DepsModule,
DepsPip,
Lepton,
TransportableObject,
Expand Down
1 change: 1 addition & 0 deletions covalent/_workflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from .depsbash import DepsBash
from .depscall import DepsCall
from .depsmodule import DepsModule
from .depspip import DepsPip
from .electron import electron
from .lattice import lattice
Expand Down
15 changes: 8 additions & 7 deletions covalent/_workflow/depscall.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,18 @@ class DepsCall(Deps):
def __init__(
self,
func=None,
args=[],
kwargs={},
args=None,
kwargs=None,
*,
retval_keyword="",
override_reserved_retval_keys=False,
):
if args is None:
args = []
if kwargs is None:
kwargs = {}
if not override_reserved_retval_keys and retval_keyword in [RESERVED_RETVAL_KEY__FILES]:
raise Exception(
raise RuntimeError(
f"The retval_keyword for the specified DepsCall uses the reserved value '{retval_keyword}' please re-name to use another return value keyword."
)

Expand All @@ -70,10 +74,7 @@ def to_dict(self) -> dict:
"""Return a JSON-serializable dictionary representation of self"""
attributes = self.__dict__.copy()
for k, v in attributes.items():
if isinstance(v, TransportableObject):
attributes[k] = v.to_dict()
else:
attributes[k] = v
attributes[k] = v.to_dict() if isinstance(v, TransportableObject) else v
return {"type": "DepsCall", "short_name": self.short_name(), "attributes": attributes}

def from_dict(self, object_dict) -> "DepsCall":
Expand Down
56 changes: 56 additions & 0 deletions covalent/_workflow/depsmodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2023 Agnostiq Inc.
#
# This file is part of Covalent.
#
# Licensed under the Apache License 2.0 (the "License"). A copy of the
# License may be obtained with this software package or at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Use of this file is prohibited except in compliance with the License.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
from types import ModuleType
from typing import Union

from .depscall import DepsCall
from .transportable_object import TransportableObject


class DepsModule(DepsCall):
"""
Python modules to be imported in an electron's execution environment.
This is only used as a vehicle to send the module by reference to the
to the right place of serialization where it will instead be pickled by value.
Deps class to encapsulate python modules to be
imported in the same execution environment as the electron.
Note: This subclasses the DepsCall class due to its pre-existing
infrastructure integrations, and not because of its logical functionality.
Attributes:
module_name: A string containing the name of the module to be imported.
"""

def __init__(self, module: Union[str, ModuleType]):
if isinstance(module, str):
# Import the module on the client side
module = importlib.import_module(module)

# Temporarily pickling the module by reference
# so that it can be pickled by value when serializing
# the transport graph.
self.pickled_module = TransportableObject(module)

super().__init__()

def short_name(self):
"""Returns the short name of this class."""
return "depsmodule"
31 changes: 28 additions & 3 deletions covalent/_workflow/electron.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from copy import deepcopy
from dataclasses import asdict
from functools import wraps
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Union

from covalent._dispatcher_plugins.local import LocalDispatcher
Expand All @@ -49,6 +50,7 @@
)
from .depsbash import DepsBash
from .depscall import RESERVED_RETVAL_KEY__FILES, DepsCall
from .depsmodule import DepsModule
from .depspip import DepsPip
from .lattice import Lattice
from .transport import TransportableObject, encode_metadata
Expand Down Expand Up @@ -207,6 +209,7 @@ def func_for_op(arg_1: Union[Any, "Electron"], arg_2: Union[Any, "Electron"]) ->
metadata = encode_metadata(DEFAULT_METADATA_VALUES.copy())
executor = metadata["workflow_executor"]
executor_data = metadata["workflow_executor_data"]

op_electron = Electron(func_for_op, metadata=metadata)

if active_lattice := active_lattice_manager.get_active_lattice():
Expand Down Expand Up @@ -545,8 +548,8 @@ def connect_node_with_others(
arg_index=arg_index,
)

elif isinstance(param_value, list):

elif isinstance(param_value, (list, tuple, set)):
# Tuples and sets will also be converted to lists
def _auto_list_node(*args, **kwargs):
return list(args)

Expand Down Expand Up @@ -589,6 +592,7 @@ def _auto_dict_node(*args, **kwargs):

else:
encoded_param_value = TransportableObject.make_transportable(param_value)

parameter_node = transport_graph.add_node(
name=parameter_prefix + str(param_value),
function=None,
Expand Down Expand Up @@ -698,6 +702,7 @@ def electron(
files: List[FileTransfer] = [],
deps_bash: Union[DepsBash, List, str] = None,
deps_pip: Union[DepsPip, list] = None,
deps_module: Union[DepsModule, List[DepsModule], str, List[str]] = None,
call_before: Union[List[DepsCall], DepsCall] = None,
call_after: Union[List[DepsCall], DepsCall] = None,
) -> Callable: # sourcery skip: assign-if-exp
Expand All @@ -713,6 +718,7 @@ def electron(
executor is used by default.
deps_bash: An optional DepsBash object specifying a list of shell commands to run before `_func`
deps_pip: An optional DepsPip object specifying a list of PyPI packages to install before running `_func`
deps_module: An optional DepsModule (or similar) object specifying which user modules to load before running `_func`
call_before: An optional list of DepsCall objects specifying python functions to invoke before the electron
call_after: An optional list of DepsCall objects specifying python functions to invoke after the electron
files: An optional list of FileTransfer objects which copy files to/from remote or local filesystems.
Expand Down Expand Up @@ -757,6 +763,25 @@ def electron(
else:
internal_call_before_deps.append(DepsCall(_file_transfer_call_dep_))

if deps_module:
if isinstance(deps_module, list):
# Convert to DepsModule objects
converted_deps = []
for dep in deps_module:
if type(dep) in [str, ModuleType]:
converted_deps.append(DepsModule(dep))
else:
converted_deps.append(dep)
deps_module = converted_deps

elif type(deps_module) in [str, ModuleType]:
deps_module = [DepsModule(deps_module)]

elif isinstance(deps_module, DepsModule):
deps_module = [deps_module]

internal_call_before_deps.extend(deps_module)

if isinstance(deps_pip, DepsPip):
deps["pip"] = deps_pip
if isinstance(deps_pip, list):
Expand Down Expand Up @@ -887,7 +912,7 @@ def _build_sublattice_graph(sub: Lattice, json_parent_metadata: str, *args, **kw
)
LocalDispatcher.upload_assets(recv_manifest)

return recv_manifest.json()
return recv_manifest.model_dump_json()

except Exception as ex:
# Fall back to legacy sublattice handling
Expand Down
16 changes: 11 additions & 5 deletions covalent/_workflow/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@
from .depscall import DepsCall
from .depspip import DepsPip
from .postprocessing import Postprocessor
from .transport import TransportableObject, _TransportGraph, encode_metadata
from .transport import (
TransportableObject,
_TransportGraph,
add_module_deps_to_lattice_metadata,
encode_metadata,
)

if TYPE_CHECKING:
from .._results_manager.result import Result
Expand Down Expand Up @@ -234,10 +239,11 @@ def build_graph(self, *args, **kwargs) -> None:

pp = Postprocessor(lattice=self)

if get_config("sdk.exhaustive_postprocess") == "true":
pp.add_exhaustive_postprocess_node(self._bound_electrons.copy())
else:
pp.add_reconstruct_postprocess_node(retval, self._bound_electrons.copy())
with add_module_deps_to_lattice_metadata(pp, self._bound_electrons):
if get_config("sdk.exhaustive_postprocess") == "true":
pp.add_exhaustive_postprocess_node(self._bound_electrons.copy())
else:
pp.add_reconstruct_postprocess_node(retval, self._bound_electrons.copy())

self._bound_electrons = {} # Reset bound electrons

Expand Down
1 change: 1 addition & 0 deletions covalent/_workflow/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def _get_electron_metadata(self) -> Dict:

pp_metadata["executor"] = pp_metadata.pop("workflow_executor")
pp_metadata["executor_data"] = pp_metadata.pop("workflow_executor_data")

return pp_metadata

def add_exhaustive_postprocess_node(self, bound_electrons: Dict) -> None:
Expand Down
91 changes: 89 additions & 2 deletions covalent/_workflow/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import datetime
import json
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, Callable, Dict

Expand Down Expand Up @@ -94,6 +95,88 @@ def encode_metadata(metadata: dict) -> dict:
return encoded_metadata


@contextmanager
def pickle_modules_by_value(metadata):
"""
Pickle modules in a context manager by value.
Args:
call_before: The call before metadata.
Returns:
None
"""

call_before = metadata.get("hooks", {}).get("call_before")

if not call_before:
yield metadata
return

new_metadata = deepcopy(metadata)
call_before = new_metadata.get("hooks", {}).get("call_before")

list_of_modules = []

for i in range(len(call_before)):
# Extract the pickled module if the call before is a DepsModule
if call_before[i]["short_name"] == "depsmodule":
pickled_module = call_before[i]["attributes"]["pickled_module"]
list_of_modules.append(
TransportableObject.from_dict(pickled_module).get_deserialized()
)

# Delete the DepsModule from new_metadata of the electron
new_metadata["hooks"]["call_before"][i] = None

# Remove the None values from the call_before list of this electron
new_metadata["hooks"]["call_before"] = list(filter(None, new_metadata["hooks"]["call_before"]))

for module in list_of_modules:
cloudpickle.register_pickle_by_value(module)

yield new_metadata

for module in list_of_modules:
try:
cloudpickle.unregister_pickle_by_value(module)
except ValueError:
continue


@contextmanager
def add_module_deps_to_lattice_metadata(pp, bound_electrons: Dict[int, Any]):
old_lattice_metadata = deepcopy(pp.lattice.metadata)

# Add the module dependencies to the lattice metadata
for electron in bound_electrons.values():
call_before = electron.metadata.get("hooks", {}).get("call_before")
if call_before:
for i in range(len(call_before)):
if call_before[i]["short_name"] == "depsmodule":
if "hooks" not in pp.lattice.metadata:
pp.lattice.metadata["hooks"] = {}

if "call_before" not in pp.lattice.metadata["hooks"]:
pp.lattice.metadata["hooks"]["call_before"] = []

# Temporarily add the module metadat to the lattice metadata for postprocessing
pp.lattice.metadata["hooks"]["call_before"].append(call_before[i])

# Delete the DepsModule from the electron metadata
electron.metadata["hooks"]["call_before"][i] = None

# Remove the None values from the call_before list of this electron
electron.metadata["hooks"]["call_before"] = list(
filter(None, electron.metadata["hooks"]["call_before"])
)

yield

# Restore the old lattice metadata
pp.lattice.metadata = old_lattice_metadata


class _TransportGraph:
"""
A TransportGraph is the most essential part of the whole workflow. This contains
Expand Down Expand Up @@ -131,6 +214,7 @@ def add_node(
) -> int:
"""
Adds a node to the graph.
Also serializes the received function.
Args:
name: The name of the node.
Expand All @@ -148,14 +232,17 @@ def add_node(
if task_group_id is None:
task_group_id = node_id

with pickle_modules_by_value(metadata) as new_metadata:
serialized_function = TransportableObject(function)

# Default to gid=node_id

self._graph.add_node(
node_id,
task_group_id=task_group_id,
name=name,
function=TransportableObject(function),
metadata=metadata,
function=serialized_function,
metadata=new_metadata, # Save the new metadata without the DepsModules in it
**attr,
)

Expand Down
Loading

0 comments on commit 6e331d4

Please sign in to comment.