From 5ee8e98be89b0febb07b57117952937d276cd529 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Wed, 19 Jun 2024 12:37:25 +0100 Subject: [PATCH] [nnx] add PathContains Filter --- flax/nnx/__init__.py | 1 + flax/nnx/nnx/filterlib.py | 9 ++++++++- flax/nnx/tests/filters_test.py | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 flax/nnx/tests/filters_test.py diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 575409bd8b..555f5dc66e 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -25,6 +25,7 @@ from .nnx import compat as compat from .nnx import traversals as traversals from .nnx.filterlib import All as All +from .nnx.filterlib import PathContains as PathContains from .nnx.filterlib import Not as Not from .nnx.graph import GraphDef as GraphDef from .nnx.graph import GraphState as GraphState diff --git a/flax/nnx/nnx/filterlib.py b/flax/nnx/nnx/filterlib.py index 4978cf0204..97b4765afb 100644 --- a/flax/nnx/nnx/filterlib.py +++ b/flax/nnx/nnx/filterlib.py @@ -14,7 +14,7 @@ import builtins import dataclasses -from flax.typing import PathParts +from flax.typing import Key, PathParts import typing as tp if tp.TYPE_CHECKING: @@ -66,6 +66,13 @@ class WithTag: def __call__(self, path: PathParts, x: tp.Any): return isinstance(x, _HasTag) and x.tag == self.tag +@dataclasses.dataclass +class PathContains: + key: Key + + def __call__(self, path: PathParts, x: tp.Any): + return self.key in path + @dataclasses.dataclass class OfType: diff --git a/flax/nnx/tests/filters_test.py b/flax/nnx/tests/filters_test.py new file mode 100644 index 0000000000..5fbe9758ce --- /dev/null +++ b/flax/nnx/tests/filters_test.py @@ -0,0 +1,33 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from absl.testing import absltest + +from flax import nnx + + +class TestFilters(absltest.TestCase): + def test_path_contains(self): + class Model(nnx.Module): + def __init__(self, rngs): + self.backbone = nnx.Linear(2, 3, rngs=rngs) + self.head = nnx.Linear(3, 10, rngs=rngs) + + model = Model(nnx.Rngs(0)) + + head_state = nnx.state(model, nnx.PathContains('head')) + + self.assertIn('head', head_state) + self.assertNotIn('backbone', head_state) \ No newline at end of file