Merge pull request #4011 from google:nnx-path-filter
PiperOrigin-RevId: 645354554
Flax Authors authored and cgarciae committed Jun 24, 2024
2 parents 54fe15f + 5ee8e98 commit 3781f2a
Showing 6 changed files with 260 additions and 3 deletions.
16 changes: 16 additions & 0 deletions docs/api_reference/flax.nnx/filterlib.rst
@@ -0,0 +1,16 @@
filterlib
------------------------

.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autofunction:: flax.nnx.filterlib.to_predicate

.. autoclass:: WithTag
.. autoclass:: PathContains
.. autoclass:: OfType
.. autoclass:: Any
.. autoclass:: All
.. autoclass:: Not
.. autoclass:: Everything
.. autoclass:: Nothing
1 change: 1 addition & 0 deletions docs/api_reference/flax.nnx/index.rst
@@ -11,6 +11,7 @@ Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/i
nn/index
rnglib
spmd
filterlib
state
training/index
transforms
147 changes: 147 additions & 0 deletions docs/nnx/filters_guide.ipynb
@@ -0,0 +1,147 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using Filters\n",
"\n",
"Filters are used extensively in NNX as a way to create `State` groups in APIs\n",
"such as `nnx.split` and `nnx.state`. For example:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"params = State({\n",
" 'a': VariableState(\n",
" type=Param,\n",
" value=0\n",
" )\n",
"})\n",
"batch_stats = State({\n",
" 'b': VariableState(\n",
" type=BatchStat,\n",
" value=True\n",
" )\n",
"})\n"
]
}
],
"source": [
"from flax import nnx\n",
"\n",
"class Foo(nnx.Module):\n",
" def __init__(self):\n",
" self.a = nnx.Param(0)\n",
" self.b = nnx.BatchStat(True)\n",
"\n",
"foo = Foo()\n",
"\n",
"graphdef, params, batch_stats = nnx.split(foo, nnx.Param, nnx.BatchStat)\n",
"\n",
"print(f'{params = }')\n",
"print(f'{batch_stats = }')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here `nnx.Param` and `nnx.BatchStat` are used as Filters to split the model into two groups: one with the parameters and the other with the batch statistics. However this begs the questions:\n",
"\n",
"* What is a Filter?\n",
"* Why are types, such as `Param` or `BatchStat`, Filters?\n",
"* How is `State` grouped / filtered?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"## The Filter Protocol\n",
"\n",
"In general Filter are predicate functions of the form:\n",
"\n",
"```python\n",
"\n",
"def f(path: tuple[Key, ...], value: Any) -> bool:\n",
"\n",
"```\n",
"where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the path to the value in a nested structure, and `value` is the value at the path. The function returns `True` if the value should be included in the group and `False` otherwise.\n",
"\n",
"Types are obviously not functions. The reason why types are used as Filters is because internally they are converted to functions. For example, `Param` is roughly converted to a function like this:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def is_param(path, value) -> bool:\n",
" return isinstance(value, nnx.Param) or (\n",
" hasattr(value, 'type') and issubclass(value.type, nnx.Param)\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Such a function matches any value that is an instance of `Param` or any value that hash a `type` attribute that is an instance of a subclass of `Param`. Internally NNX uses `OfType` which creates a callable of the form for a given type."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"is_param((), nnx.Param(0)) = True\n",
"is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True\n"
]
}
],
"source": [
"is_param = nnx.OfType(nnx.Param)\n",
"\n",
"print(f'{is_param((), nnx.Param(0)) = }')\n",
"print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
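
Note on the guide: the notebook above asks how `State` is grouped but stops before answering. The following is a minimal pure-Python sketch of the predicate protocol it describes. It assumes that each value goes to the first predicate that matches it and that unmatched values are skipped; both rules are simplifications for illustration and are not taken from this commit. `of_type` and `group_by_first_match` are hypothetical names, not `flax.nnx` APIs.

```python
from typing import Any, Callable, Hashable

Path = tuple[Hashable, ...]
Predicate = Callable[[Path, Any], bool]


def of_type(cls: type) -> Predicate:
  """Hypothetical stand-in for a type filter: match values by isinstance."""
  def predicate(path: Path, value: Any) -> bool:
    return isinstance(value, cls)
  return predicate


def group_by_first_match(
    flat_state: dict[Path, Any], *predicates: Predicate
) -> list[dict[Path, Any]]:
  """Assumed grouping rule: each value lands in the first matching group."""
  groups: list[dict[Path, Any]] = [{} for _ in predicates]
  for path, value in flat_state.items():
    for group, predicate in zip(groups, predicates):
      if predicate(path, value):
        group[path] = value
        break  # first match wins; later predicates never see this value
  return groups


# A flat mapping from paths to values stands in for a real State here.
flat_state = {('a',): 0, ('b', 'c'): 'x', ('d',): 1.5}
ints, strs = group_by_first_match(flat_state, of_type(int), of_type(str))
print(ints)
print(strs)
```

The float at `('d',)` matches neither predicate, so in this sketch it appears in neither group.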
8 changes: 8 additions & 0 deletions flax/nnx/__init__.py
@@ -24,8 +24,16 @@
 from .nnx import helpers as helpers
 from .nnx import compat as compat
 from .nnx import traversals as traversals
 from .nnx import filterlib as filterlib

+from .nnx.filterlib import WithTag as WithTag
+from .nnx.filterlib import PathContains as PathContains
+from .nnx.filterlib import OfType as OfType
+from .nnx.filterlib import Any as Any
+from .nnx.filterlib import All as All
+from .nnx.filterlib import Not as Not
+from .nnx.filterlib import Everything as Everything
+from .nnx.filterlib import Nothing as Nothing
 from .nnx.graph import GraphDef as GraphDef
 from .nnx.graph import GraphState as GraphState
 from .nnx.object import Object as Object
58 changes: 55 additions & 3 deletions flax/nnx/nnx/filterlib.py
@@ -14,7 +14,7 @@

import builtins
import dataclasses
-from flax.typing import PathParts
+from flax.typing import Key, PathParts
import typing as tp

if tp.TYPE_CHECKING:
@@ -59,15 +59,22 @@ def to_predicate(filter: Filter) -> Predicate:
     raise TypeError(f'Invalid collection filter: {filter!r}. ')


-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
 class WithTag:
   tag: str

   def __call__(self, path: PathParts, x: tp.Any):
     return isinstance(x, _HasTag) and x.tag == self.tag

+@dataclasses.dataclass(frozen=True)
+class PathContains:
+  key: Key

-@dataclasses.dataclass
+  def __call__(self, path: PathParts, x: tp.Any):
+    return self.key in path
+
+
+@dataclasses.dataclass(frozen=True)
 class OfType:
   type: type

@@ -86,6 +93,15 @@ def __init__(self, *filters: Filter):
   def __call__(self, path: PathParts, x: tp.Any):
     return any(predicate(path, x) for predicate in self.predicates)

+  def __repr__(self):
+    return f'Any({", ".join(map(repr, self.predicates))})'
+
+  def __eq__(self, other):
+    return isinstance(other, Any) and self.predicates == other.predicates
+
+  def __hash__(self):
+    return hash(self.predicates)
+

 class All:
   def __init__(self, *filters: Filter):
@@ -96,6 +112,15 @@ def __init__(self, *filters: Filter):
   def __call__(self, path: PathParts, x: tp.Any):
     return all(predicate(path, x) for predicate in self.predicates)

+  def __repr__(self):
+    return f'All({", ".join(map(repr, self.predicates))})'
+
+  def __eq__(self, other):
+    return isinstance(other, All) and self.predicates == other.predicates
+
+  def __hash__(self):
+    return hash(self.predicates)
+

 class Not:
   def __init__(self, collection_filter: Filter, /):
@@ -104,12 +129,39 @@ def __init__(self, collection_filter: Filter, /):
   def __call__(self, path: PathParts, x: tp.Any):
     return not self.predicate(path, x)

+  def __repr__(self):
+    return f'Not({self.predicate!r})'
+
+  def __eq__(self, other):
+    return isinstance(other, Not) and self.predicate == other.predicate
+
+  def __hash__(self):
+    return hash(self.predicate)
+

 class Everything:
   def __call__(self, path: PathParts, x: tp.Any):
     return True
+
+  def __repr__(self):
+    return 'Everything()'
+
+  def __eq__(self, other):
+    return isinstance(other, Everything)
+
+  def __hash__(self):
+    return hash(Everything)


 class Nothing:
   def __call__(self, path: PathParts, x: tp.Any):
     return False
+
+  def __repr__(self):
+    return 'Nothing()'
+
+  def __eq__(self, other):
+    return isinstance(other, Nothing)
+
+  def __hash__(self):
+    return hash(Nothing)
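
Note on the `filterlib.py` changes: switching to `frozen=True` and adding `__eq__`/`__hash__` gives filters value semantics, so two filters built the same way compare equal and hash alike. The commit does not state the motivation; the sketch below only illustrates the effect, using simplified stand-in classes that copy the shape of the diff rather than importing `flax.nnx`. The dictionary used as a cache is a hypothetical use case.

```python
import dataclasses
from typing import Any, Hashable


@dataclasses.dataclass(frozen=True)
class PathContains:
  """Stand-in mirroring the diff; frozen=True generates __eq__ and __hash__."""
  key: Hashable

  def __call__(self, path: tuple, x: Any) -> bool:
    return self.key in path


class Not:
  """Stand-in mirroring the diff's hand-written __eq__ and __hash__."""

  def __init__(self, predicate):
    self.predicate = predicate

  def __call__(self, path: tuple, x: Any) -> bool:
    return not self.predicate(path, x)

  def __eq__(self, other):
    return isinstance(other, Not) and self.predicate == other.predicate

  def __hash__(self):
    return hash(self.predicate)


a = Not(PathContains('head'))
b = Not(PathContains('head'))  # a distinct object with the same structure

# Hypothetical use: a dict keyed by filter, as a cache might be.
cache = {a: 'result computed for this filter'}
print(a is b, a == b, hash(a) == hash(b))
print(cache[b])  # found via b, because b equals a and hashes the same
```

Without `__eq__` and `__hash__`, `Not` would fall back to identity comparison and the lookup through `b` would raise `KeyError`.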
33 changes: 33 additions & 0 deletions flax/nnx/tests/filters_test.py
@@ -0,0 +1,33 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from absl.testing import absltest

from flax import nnx


class TestFilters(absltest.TestCase):
  def test_path_contains(self):
    class Model(nnx.Module):
      def __init__(self, rngs):
        self.backbone = nnx.Linear(2, 3, rngs=rngs)
        self.head = nnx.Linear(3, 10, rngs=rngs)

    model = Model(nnx.Rngs(0))

    head_state = nnx.state(model, nnx.PathContains('head'))

    self.assertIn('head', head_state)
    self.assertNotIn('backbone', head_state)
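
Note on the test: `PathContains` relies on `self.key in path`, which is tuple membership. A key therefore matches at any depth of the path, but only by equality with a whole path element, not by substring. A minimal sketch of that behaviour, assuming a plain dict from path tuples to values as a stand-in for a real `State`; the `select` helper and the example paths are hypothetical.

```python
import dataclasses
from typing import Any, Hashable


@dataclasses.dataclass(frozen=True)
class PathContains:
  """Stand-in copying the class added in this commit."""
  key: Hashable

  def __call__(self, path: tuple, x: Any) -> bool:
    return self.key in path  # tuple membership: equality per element


def select(flat_state: dict, predicate) -> dict:
  """Hypothetical helper: keep the entries the predicate accepts."""
  return {p: v for p, v in flat_state.items() if predicate(p, v)}


flat_state = {
  ('backbone', 'kernel'): 'w0',
  ('backbone', 'bias'): 'b0',
  ('head', 'kernel'): 'w1',
  ('head', 'bias'): 'b1',
  ('head_norm', 'scale'): 's',
}

head = select(flat_state, PathContains('head'))
kernels = select(flat_state, PathContains('kernel'))
print(sorted(head))
print(sorted(kernels))
```

In this sketch `('head_norm', 'scale')` is not selected by `PathContains('head')`, because `'head'` is not equal to any element of that path.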
