diff --git a/docs/api_reference/flax.nnx/filterlib.rst b/docs/api_reference/flax.nnx/filterlib.rst new file mode 100644 index 0000000000..f3a7e7dc37 --- /dev/null +++ b/docs/api_reference/flax.nnx/filterlib.rst @@ -0,0 +1,16 @@ +filterlib +------------------------ + +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx + +.. autofunction:: flax.nnx.filterlib.to_predicate + +.. autoclass:: WithTag +.. autoclass:: PathContains +.. autoclass:: OfType +.. autoclass:: Any +.. autoclass:: All +.. autoclass:: Not +.. autoclass:: Everything +.. autoclass:: Nothing \ No newline at end of file diff --git a/docs/api_reference/flax.nnx/index.rst b/docs/api_reference/flax.nnx/index.rst index 957c99567d..1a4ed808a9 100644 --- a/docs/api_reference/flax.nnx/index.rst +++ b/docs/api_reference/flax.nnx/index.rst @@ -11,6 +11,7 @@ Experimental API. See the `NNX page bool:\n", + "\n", + "```\n", + "where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the path to the value in a nested structure, and `value` is the value at the path. The function returns `True` if the value should be included in the group and `False` otherwise.\n", + "\n", + "Types are obviously not functions. The reason why types are used as Filters is because internally they are converted to functions. For example, `Param` is roughly converted to a function like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def is_param(path, value) -> bool:\n", + " return isinstance(value, nnx.Param) or (\n", + " hasattr(value, 'type') and issubclass(value.type, nnx.Param)\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Such a function matches any value that is an instance of `Param` or any value that hash a `type` attribute that is an instance of a subclass of `Param`. Internally NNX uses `OfType` which creates a callable of the form for a given type." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "is_param((), nnx.Param(0)) = True\n", + "is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True\n" + ] + } + ], + "source": [ + "is_param = nnx.OfType(nnx.Param)\n", + "\n", + "print(f'{is_param((), nnx.Param(0)) = }')\n", + "print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 575409bd8b..37b10d4b0f 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -24,8 +24,16 @@ from .nnx import helpers as helpers from .nnx import compat as compat from .nnx import traversals as traversals +from .nnx import filterlib as filterlib + +from .nnx.filterlib import WithTag as WithTag +from .nnx.filterlib import PathContains as PathContains +from .nnx.filterlib import OfType as OfType +from .nnx.filterlib import Any as Any from .nnx.filterlib import All as All from .nnx.filterlib import Not as Not +from .nnx.filterlib import Everything as Everything +from .nnx.filterlib import Nothing as Nothing from .nnx.graph import GraphDef as GraphDef from .nnx.graph import GraphState as GraphState from .nnx.object import Object as Object diff --git a/flax/nnx/nnx/filterlib.py b/flax/nnx/nnx/filterlib.py index 4978cf0204..81844f415e 100644 --- a/flax/nnx/nnx/filterlib.py +++ b/flax/nnx/nnx/filterlib.py @@ -14,7 +14,7 @@ import builtins import dataclasses -from flax.typing import PathParts +from flax.typing import Key, PathParts import typing as tp if tp.TYPE_CHECKING: @@ -59,15 +59,22 @@ def to_predicate(filter: Filter) -> Predicate: raise TypeError(f'Invalid collection filter: {filter:!r}. ') -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class WithTag: tag: str def __call__(self, path: PathParts, x: tp.Any): return isinstance(x, _HasTag) and x.tag == self.tag +@dataclasses.dataclass(frozen=True) +class PathContains: + key: Key -@dataclasses.dataclass + def __call__(self, path: PathParts, x: tp.Any): + return self.key in path + + +@dataclasses.dataclass(frozen=True) class OfType: type: type @@ -86,6 +93,15 @@ def __init__(self, *filters: Filter): def __call__(self, path: PathParts, x: tp.Any): return any(predicate(path, x) for predicate in self.predicates) + def __repr__(self): + return f'Any({", ".join(map(repr, self.predicates))})' + + def __eq__(self, other): + return isinstance(other, Any) and self.predicates == other.predicates + + def __hash__(self): + return hash(self.predicates) + class All: def __init__(self, *filters: Filter): @@ -96,6 +112,15 @@ def __init__(self, *filters: Filter): def __call__(self, path: PathParts, x: tp.Any): return all(predicate(path, x) for predicate in self.predicates) + def __repr__(self): + return f'All({", ".join(map(repr, self.predicates))})' + + def __eq__(self, other): + return isinstance(other, All) and self.predicates == other.predicates + + def __hash__(self): + return hash(self.predicates) + class Not: def __init__(self, collection_filter: Filter, /): @@ -104,12 +129,39 @@ def __init__(self, collection_filter: Filter, /): def __call__(self, path: PathParts, x: tp.Any): return not self.predicate(path, x) + def __repr__(self): + return f'Not({self.predicate!r})' + + def __eq__(self, other): + return isinstance(other, Not) and self.predicate == other.predicate + + def __hash__(self): + return hash(self.predicate) + class Everything: def __call__(self, path: PathParts, x: tp.Any): return True + def __repr__(self): + return 'Everything()' + + def __eq__(self, other): + return isinstance(other, Everything) + + def __hash__(self): + return hash(Everything) + class Nothing: def __call__(self, path: PathParts, x: tp.Any): return False + + def __repr__(self): + return 'Nothing()' + + def __eq__(self, other): + return isinstance(other, Nothing) + + def __hash__(self): + return hash(Nothing) diff --git a/flax/nnx/tests/filters_test.py b/flax/nnx/tests/filters_test.py new file mode 100644 index 0000000000..5fbe9758ce --- /dev/null +++ b/flax/nnx/tests/filters_test.py @@ -0,0 +1,33 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from absl.testing import absltest + +from flax import nnx + + +class TestFilters(absltest.TestCase): + def test_path_contains(self): + class Model(nnx.Module): + def __init__(self, rngs): + self.backbone = nnx.Linear(2, 3, rngs=rngs) + self.head = nnx.Linear(3, 10, rngs=rngs) + + model = Model(nnx.Rngs(0)) + + head_state = nnx.state(model, nnx.PathContains('head')) + + self.assertIn('head', head_state) + self.assertNotIn('backbone', head_state) \ No newline at end of file