Skip to content

Commit

Permalink
Add more tests for zntrack.apply (#800)
Browse files Browse the repository at this point in the history
* write test for issue with `zntrack.apply` can not pickle local object #799

* add assertions

* more testing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] committed May 17, 2024
1 parent d6f96e9 commit f7d2d7c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
33 changes: 33 additions & 0 deletions tests/integration/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,36 @@ def test_apply(proj_path, eager) -> None:
assert a.outs == ["a", "b"]
assert b.outs == "a-b"
assert c.outs == "a-b-c"


@pytest.mark.parametrize("attribute", [True, False])
@pytest.mark.parametrize("eager", [True, False])
def test_deps_apply(proj_path, eager, attribute):
"""Test connecting applied nodes to other nodes."""

project = zntrack.Project()

JoinedParamsToOuts = zntrack.apply(zntrack.examples.ParamsToOuts, "join")

assert issubclass(JoinedParamsToOuts, zntrack.Node)

with project:
a = zntrack.examples.ParamsToOuts(params=["a", "b"])
b = JoinedParamsToOuts(params=["a", "b"])
c = zntrack.apply(zntrack.examples.ParamsToOuts, "join")(params=["a", "b", "c"])

if attribute:
x3 = zntrack.examples.AddNodeAttributes(a=b.outs, b=c.outs)
else:
x3 = zntrack.examples.AddNodes2(a=b, b=c)

project.run(eager=eager)

x3.load()

assert isinstance(a, zntrack.Node)
assert isinstance(b, zntrack.Node)
assert isinstance(c, zntrack.Node)
assert isinstance(x3, zntrack.Node)

assert x3.c == "a-ba-b-c"
12 changes: 12 additions & 0 deletions zntrack/examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ def run(self):
self.c = self.a.c + self.b.c


class AddNodes2(zntrack.Node):
"""Add two nodes."""

a: AddNumbers = zntrack.deps()
b: AddNumbers = zntrack.deps()
c = zntrack.outs()

def run(self):
"""Add two nodes."""
self.c = self.a.outs + self.b.outs


class AddNodeAttributes(zntrack.Node):
"""Add two node attributes."""

Expand Down

0 comments on commit f7d2d7c

Please sign in to comment.