diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 92169083d6..caee218384 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -27,6 +27,7 @@ from .nnx.filterlib import All as All from .nnx.filterlib import Not as Not from .nnx.graph import GraphDef as GraphDef +from .nnx.graph import GraphState as GraphState from .nnx.object import Object as Object from .nnx.helpers import Dict as Dict from .nnx.helpers import List as List diff --git a/flax/nnx/nnx/graph.py b/flax/nnx/nnx/graph.py index 08ecf15c50..b3749c920a 100644 --- a/flax/nnx/nnx/graph.py +++ b/flax/nnx/nnx/graph.py @@ -26,21 +26,13 @@ import numpy as np import typing_extensions as tpe -from flax.nnx.nnx import ( - filterlib, - reprlib, -) +from flax.nnx.nnx import filterlib, reprlib from flax.nnx.nnx.proxy_caller import ( ApplyCaller, CallableProxy, DelayedAccessor, ) -from flax.nnx.nnx.state import ( - FlatState, - State, - StateLeaf, - is_state_leaf, -) +from flax.nnx.nnx.state import FlatState, State from flax.nnx.nnx.variables import Variable, VariableState from flax.typing import Key, PathParts @@ -58,16 +50,13 @@ Leaf = tp.TypeVar('Leaf') AuxData = tp.TypeVar('AuxData') -Updates = tp.Union[ - A, - 'GraphDef[A]', - tuple['GraphDef[A]', State], - tuple['GraphDef[A]', tuple[State, ...]], - State, - tuple[State, ...], -] +StateLeaf = tp.Union[VariableState[tp.Any], np.ndarray, jax.Array] +GraphState = State[Key, StateLeaf] +GraphFlatState = FlatState[StateLeaf] -NodeLeaf = tp.Union[Variable[tp.Any], np.ndarray, jax.Array] + +def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]: + return isinstance(x, (VariableState, np.ndarray, jax.Array)) @dataclasses.dataclass @@ -80,7 +69,7 @@ class GraphContext(threading.local): GRAPH_CONTEXT = GraphContext() -def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]: +def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]: return isinstance(x, (Variable, np.ndarray, jax.Array)) @@ -349,13 +338,13 @@ def __eq__(self, other): return isinstance(other, GraphDef) and self.nodedef == other.nodedef def apply( - self, state: State, *states: State - ) -> ApplyCaller[tuple['GraphDef[Node]', State]]: + self, state: GraphState, *states: GraphState + ) -> ApplyCaller[tuple['GraphDef[Node]', GraphState]]: accessor = DelayedAccessor() def _apply( accessor: DelayedAccessor, *args, **kwargs - ) -> tuple[tp.Any, tuple[GraphDef[Node], State]]: + ) -> tuple[tp.Any, tuple[GraphDef[Node], GraphState]]: module = merge(self, state, *states) fn = accessor(module) out = fn(*args, **kwargs) @@ -363,10 +352,6 @@ def _apply( return CallableProxy(_apply, accessor) # type: ignore - def make_empty(self) -> Node: - return merge(self, State({})) - - def _graphdef_flatten(graphdef: GraphDef[Node]): # refmap is opaque, we don't propagate it static = (graphdef.nodedef, graphdef.index_mapping) @@ -392,7 +377,7 @@ def flatten( /, *, idxmap: dict[Index, tp.Any] | None = None, -) -> tuple[GraphDef[Node], State, RefMap[tp.Any, Index]]: +) -> tuple[GraphDef[Node], GraphState, RefMap[tp.Any, Index]]: refmap = RefMap[tp.Any, Index]() flat_state: dict[PathParts, StateLeaf] = {} nodedef = _graph_flatten((), refmap, flat_state, x) @@ -402,7 +387,7 @@ def flatten( else: index_to_index = None graphdef = GraphDef(nodedef, index_to_index) - return graphdef, State.from_flat_path(flat_state), refmap + return graphdef, GraphState.from_flat_path(flat_state), refmap def _graph_flatten( @@ -462,7 +447,7 @@ def _graph_flatten( def unflatten( graphdef: GraphDef[Node], - state: State, + state: GraphState, /, *, idxmap: dict[Index, tp.Any] | None = None, @@ -516,7 +501,7 @@ def _graph_unflatten( node_impl = get_node_impl_for_type(nodedef.type) def _get_children(): - children: dict[Key, NodeLeaf | Node] = {} + children: dict[Key, StateLeaf | Node] = {} # NOTE: we could allw adding new StateLeafs here if unkown_keys := set(state) - set(nodedef.attributes): @@ -650,20 +635,22 @@ def _get_children(): def graph_pop( node: tp.Any, filters: tuple[filterlib.Filter, ...], -) -> tuple[State, ...]: +) -> tuple[GraphState, ...]: id_to_index: dict[int, Index] = {} path_parts: PathParts = () predicates = tuple(filterlib.to_predicate(filter) for filter in filters) - flat_states: tuple[FlatState, ...] = tuple({} for _ in predicates) + flat_states: tuple[GraphFlatState, ...] = tuple({} for _ in predicates) _graph_pop(node, id_to_index, path_parts, flat_states, predicates) - return tuple(State.from_flat_path(flat_state) for flat_state in flat_states) + return tuple( + GraphState.from_flat_path(flat_state) for flat_state in flat_states + ) def _graph_pop( node: tp.Any, id_to_index: dict[int, Index], path_parts: PathParts, - flat_states: tuple[FlatState, ...], + flat_states: tuple[GraphFlatState, ...], predicates: tuple[filterlib.Predicate, ...], ) -> None: if not is_node(node): @@ -888,12 +875,12 @@ def __eq__(self, other): return isinstance(other, UpdateContext) @tp.overload - def split(self, graph_node: A, /) -> tuple[GraphDef[A], State]: ... + def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ... @tp.overload def split( self, graph_node: A, first: filterlib.Filter, / - ) -> tuple[GraphDef[A], State]: ... + ) -> tuple[GraphDef[A], GraphState]: ... @tp.overload def split( @@ -903,11 +890,11 @@ def split( second: filterlib.Filter, /, *filters: filterlib.Filter, - ) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: ... + ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ... def split( self, node: A, *filters: filterlib.Filter - ) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: + ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: """Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef contains all the static information needed to reconstruct a ``Module`` graph, it is analogous @@ -977,7 +964,7 @@ def split( ) graphdef, state, refmap = flatten(node, idxmap=self.idxmap) - states: State | tuple[State, ...] + states: GraphState | tuple[GraphState, ...] if len(filters) == 0: states = (state,) elif len(filters) == 1: @@ -997,15 +984,15 @@ def split( def merge( self, graphdef: GraphDef[A], - state: State, - *states: State, + state: GraphState, + *states: GraphState, ) -> A: """merge""" if self.refmap is None: raise ValueError('Cannot update a graphdef without refmap.') if states: - state = State.merge(state, *states) + state = GraphState.merge(state, *states) if graphdef.index_mapping is None: node, self.idxmap = unflatten(graphdef, state) @@ -1163,7 +1150,7 @@ def current_update_context(tag: str) -> UpdateContext: @tp.overload -def split(graph_node: A, /) -> tuple[GraphDef[A], State]: ... +def split(graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ... @tp.overload @@ -1171,7 +1158,7 @@ def split( graph_node: A, first: filterlib.Filter, /, -) -> tuple[GraphDef[A], State]: ... +) -> tuple[GraphDef[A], GraphState]: ... @tp.overload @@ -1181,12 +1168,12 @@ def split( second: filterlib.Filter, /, *filters: filterlib.Filter, -) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: ... +) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ... def split( node: A, *filters: filterlib.Filter -) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: +) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: """Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef contains all the static information needed to reconstruct a ``Module`` graph, it is analogous @@ -1257,7 +1244,7 @@ def split( """ graphdef, state, _ = flatten(node) - states: State | tuple[State, ...] + states: GraphState | tuple[GraphState, ...] if len(filters) == 0: states = (state,) elif len(filters) == 1: @@ -1270,9 +1257,9 @@ def split( def merge( graphdef: GraphDef[A], - state: State, + state: GraphState, /, - *states: State, + *states: GraphState, ) -> A: """The inverse of :func:`split`. @@ -1308,25 +1295,25 @@ def merge( states: Additional :class:`State` objects. """ if states: - state = State.merge(state, *states) + state = GraphState.merge(state, *states) node, _ = unflatten(graphdef, state) return node -def update(node, state: State, /, *states: State) -> None: +def update(node, state: GraphState, /, *states: GraphState) -> None: if states: - state = State.merge(state, *states) + state = GraphState.merge(state, *states) _graph_update_dynamic(node, state.raw_mapping) @tp.overload -def state(node, /) -> State: ... +def state(node, /) -> GraphState: ... @tp.overload -def state(node, first: filterlib.Filter, /) -> State: ... +def state(node, first: filterlib.Filter, /) -> GraphState: ... @tp.overload @@ -1336,16 +1323,16 @@ def state( second: filterlib.Filter, /, *filters: filterlib.Filter, -) -> tuple[State, ...]: ... +) -> tuple[GraphState, ...]: ... def state( node, *filters: filterlib.Filter, -) -> tp.Union[State, tuple[State, ...]]: +) -> tp.Union[GraphState, tuple[GraphState, ...]]: state = flatten(node)[1] - states: State | tuple[State, ...] + states: GraphState | tuple[GraphState, ...] if len(filters) == 0: states = state elif len(filters) == 1: @@ -1366,7 +1353,7 @@ def pop( node, filter: filterlib.Filter, /, -) -> State: ... +) -> GraphState: ... @tp.overload @@ -1376,17 +1363,19 @@ def pop( filter2: filterlib.Filter, /, *filters: filterlib.Filter, -) -> tuple[State, ...]: ... +) -> tuple[GraphState, ...]: ... -def pop(node, *filters: filterlib.Filter) -> tp.Union[State, tuple[State, ...]]: +def pop( + node, *filters: filterlib.Filter +) -> tp.Union[GraphState, tuple[GraphState, ...]]: if len(filters) == 0: raise ValueError('Expected at least one filter') id_to_index: dict[int, Index] = {} path_parts: PathParts = () predicates = tuple(filterlib.to_predicate(filter) for filter in filters) - flat_states: tuple[FlatState, ...] = tuple({} for _ in predicates) + flat_states: tuple[GraphFlatState, ...] = tuple({} for _ in predicates) _graph_pop( node=node, id_to_index=id_to_index, @@ -1394,7 +1383,9 @@ def pop(node, *filters: filterlib.Filter) -> tp.Union[State, tuple[State, ...]]: flat_states=flat_states, predicates=predicates, ) - states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states) + states = tuple( + GraphState.from_flat_path(flat_state) for flat_state in flat_states + ) if len(states) == 1: return states[0] diff --git a/flax/nnx/nnx/module.py b/flax/nnx/nnx/module.py index 5488e80af8..552943b649 100644 --- a/flax/nnx/nnx/module.py +++ b/flax/nnx/nnx/module.py @@ -26,13 +26,14 @@ from flax.nnx.nnx import variables as variableslib from flax.nnx.nnx.graph import GraphDef from flax.nnx.nnx.object import Object, ObjectMeta -from flax.nnx.nnx.state import State, StateLeaf +from flax.nnx.nnx.graph import GraphState, StateLeaf +from flax.nnx.nnx.state import State from flax.typing import Path, PathParts A = tp.TypeVar('A') B = tp.TypeVar('B') M = tp.TypeVar('M', bound='Module') -S = tp.TypeVar('S', bound=tp.Union[State, tuple[State, ...]]) +S = tp.TypeVar('S', bound=tp.Union[GraphState, tuple[GraphState, ...]]) V = tp.TypeVar('V', bound=variableslib.Variable[tp.Any]) F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) diff --git a/flax/nnx/nnx/state.py b/flax/nnx/nnx/state.py index 2bf055ec1a..4ed8feef50 100644 --- a/flax/nnx/nnx/state.py +++ b/flax/nnx/nnx/state.py @@ -27,27 +27,21 @@ # limitations under the License. from __future__ import annotations -from collections.abc import Mapping +from collections.abc import MutableMapping import typing as tp -import typing_extensions as tpe import jax import jax.tree_util as jtu -import numpy as np from flax.nnx.nnx import traversals from flax.nnx.nnx import filterlib, reprlib -from flax.nnx.nnx.variables import VariableState -from flax.typing import Key, PathParts +from flax.typing import PathParts A = tp.TypeVar('A') +K = tp.TypeVar('K', bound=tp.Hashable) +V = tp.TypeVar('V') -StateLeaf = tp.Union[VariableState[tp.Any], np.ndarray, jax.Array] -FlatState = Mapping[PathParts, StateLeaf] - - -def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]: - return isinstance(x, (VariableState, np.ndarray, jax.Array)) +FlatState = dict[PathParts, V] class NestedStateRepr(reprlib.Representable): @@ -71,12 +65,13 @@ def __penzai_repr__(self, path, subtree_renderer): # Render as the dictionary itself at the same path. return subtree_renderer(children, path=path) -class State(tp.MutableMapping[Key, tp.Any], reprlib.Representable): + +class State(MutableMapping[K, V], reprlib.Representable): def __init__( self, mapping: tp.Union[ - Mapping[Key, Mapping | StateLeaf], - tp.Iterator[tuple[Key, Mapping | StateLeaf]], + tp.Mapping[K, tp.Mapping | V], + tp.Iterator[tuple[K, tp.Mapping | V]], ], /, *, @@ -98,35 +93,35 @@ def __init__( super().__setattr__('_mapping', _mapping) @property - def raw_mapping(self) -> tp.Mapping[Key, tp.Mapping[Key, tp.Any] | StateLeaf]: + def raw_mapping(self) -> tp.Mapping[K, tp.Mapping[K, tp.Any] | V]: return self._mapping # type: ignore def __contains__(self, key) -> bool: return key in self._mapping - def __getitem__(self, key: Key) -> State | StateLeaf: + def __getitem__(self, key: K) -> State | V: # type: ignore value = self._mapping[key] if isinstance(value, tp.Mapping): return State(value, _copy=False) return value - def __getattr__(self, key: Key) -> State | StateLeaf: + def __getattr__(self, key: K) -> State | V: # type: ignore[misc] if '_mapping' not in vars(self) or key not in self._mapping: raise AttributeError(f"No attribute '{key}' in State") return self[key] - def __setitem__(self, key: Key, value: State | StateLeaf) -> None: + def __setitem__(self, key: K, value: State | V) -> None: if isinstance(value, State): self._mapping[key] = value._mapping else: self._mapping[key] = value - __setattr__ = __setitem__ + __setattr__ = __setitem__ # type: ignore - def __delitem__(self, key: Key) -> None: + def __delitem__(self, key: K) -> None: del self._mapping[key] - def __iter__(self) -> tp.Iterator[Key]: + def __iter__(self) -> tp.Iterator[K]: return iter(self._mapping) def __len__(self) -> int: @@ -149,24 +144,22 @@ def __penzai_repr__(self, path, subtree_renderer): v = NestedStateRepr(v) children[k] = v return pz_repr_lib.render_dictionary_wrapper( - object_type=type(self), - wrapped_dict=children, - path=path, - subtree_renderer=subtree_renderer, + object_type=type(self), + wrapped_dict=children, + path=path, + subtree_renderer=subtree_renderer, ) - def flat_state(self) -> FlatState: - return traversals.flatten_mapping(self._mapping) # type: ignore + def flat_state(self) -> FlatState[V]: + return traverse_util.flatten_dict(self._mapping) # type: ignore @classmethod - def from_flat_path( - cls, flat_state: tp.Mapping[PathParts, StateLeaf], / - ) -> State: + def from_flat_path(cls, flat_state: tp.Mapping[PathParts, V], /) -> State: nested_state = traversals.unflatten_mapping(flat_state) return cls(nested_state) @tp.overload - def split(self, first: filterlib.Filter, /) -> 'State': ... + def split(self, first: filterlib.Filter, /) -> State[K, V]: ... @tp.overload def split( @@ -175,11 +168,11 @@ def split( second: filterlib.Filter, /, *filters: filterlib.Filter, - ) -> tuple['State', ...]: ... + ) -> tuple[State[K, V], ...]: ... def split( self, first: filterlib.Filter, /, *filters: filterlib.Filter - ) -> tp.Union['State', tuple['State', ...]]: + ) -> tp.Union[State[K, V], tuple[State[K, V], ...]]: filters = (first, *filters) *states_, rest = _split_state(self, *filters) @@ -201,7 +194,7 @@ def filter( self, first: filterlib.Filter, /, - ) -> 'State': ... + ) -> State[K, V]: ... @tp.overload def filter( @@ -210,14 +203,14 @@ def filter( second: filterlib.Filter, /, *filters: filterlib.Filter, - ) -> tuple['State', ...]: ... + ) -> tuple[State[K, V], ...]: ... def filter( self, first: filterlib.Filter, /, *filters: filterlib.Filter, - ) -> tp.Union['State', tuple['State', ...]]: + ) -> tp.Union[State[K, V], tuple[State[K, V], ...]]: *states_, _rest = _split_state(self, first, *filters) assert len(states_) == len(filters) + 1 @@ -231,25 +224,25 @@ def filter( return states # type: ignore[bad-return-type] @staticmethod - def merge(state: 'State', /, *states: 'State') -> 'State': + def merge(state: State[K, V], /, *states: State[K, V]) -> State[K, V]: states = (state, *states) if len(states) == 1: return states[0] - new_state: FlatState = {} + new_state: FlatState[V] = {} for state in states: new_state.update(state.flat_state()) # type: ignore[attribute-error] # pytype is wrong here return State.from_flat_path(new_state) - def __or__(self, other: 'State') -> 'State': + def __or__(self, other: State[K, V]) -> State[K, V]: if not other: return self return State.merge(self, other) - def __sub__(self, other: 'State') -> 'State': + def __sub__(self, other: State[K, V]) -> State[K, V]: if not other: return self @@ -267,8 +260,8 @@ def _state_flatten_with_keys(x: State): def _state_unflatten( - static: tuple[Key, ...], - leaves: tuple[StateLeaf, ...] | tuple[dict[Key, StateLeaf]], + static: tuple[K, ...], + leaves: tuple[V, ...] | tuple[dict[K, V]], ): return State(zip(static, leaves)) @@ -281,9 +274,9 @@ def _state_unflatten( def _split_state( - state: State, + state: State[K, V], *filters: filterlib.Filter, -) -> tuple[State, ...]: +) -> tuple[State[K, V], ...]: for i, filter_ in enumerate(filters): if filter_ in (..., True) and i != len(filters) - 1: remaining_filters = filters[i + 1 :] @@ -298,7 +291,7 @@ def _split_state( # we have n + 1 states, where n is the number of predicates # the last state is for values that don't match any predicate - flat_states: tuple[FlatState, ...] = tuple( + flat_states: tuple[FlatState[V], ...] = tuple( {} for _ in range(len(predicates) + 1) )