Skip to content

Commit

Permalink
[nnx] make State generic
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jun 4, 2024
1 parent 40c78bd commit 89ced50
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 104 deletions.
117 changes: 54 additions & 63 deletions flax/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,13 @@
import numpy as np
import typing_extensions as tpe

from flax.nnx.nnx import (
filterlib,
reprlib,
)
from flax.nnx.nnx import filterlib, reprlib
from flax.nnx.nnx.proxy_caller import (
ApplyCaller,
CallableProxy,
DelayedAccessor,
)
from flax.nnx.nnx.state import (
FlatState,
State,
StateLeaf,
is_state_leaf,
)
from flax.nnx.nnx.state import FlatState, State
from flax.nnx.nnx.variables import Variable, VariableState
from flax.typing import Key, PathParts

Expand All @@ -58,16 +50,13 @@
Leaf = tp.TypeVar('Leaf')
AuxData = tp.TypeVar('AuxData')

Updates = tp.Union[
A,
'GraphDef[A]',
tuple['GraphDef[A]', State],
tuple['GraphDef[A]', tuple[State, ...]],
State,
tuple[State, ...],
]
StateLeaf = tp.Union[VariableState[tp.Any], np.ndarray, jax.Array]
GraphState = State[Key, StateLeaf]
GraphFlatState = FlatState[StateLeaf]

NodeLeaf = tp.Union[Variable[tp.Any], np.ndarray, jax.Array]

def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
return isinstance(x, (VariableState, np.ndarray, jax.Array))


@dataclasses.dataclass
Expand All @@ -80,7 +69,7 @@ class GraphContext(threading.local):
GRAPH_CONTEXT = GraphContext()


def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:
def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
return isinstance(x, (Variable, np.ndarray, jax.Array))


Expand Down Expand Up @@ -349,24 +338,20 @@ def __eq__(self, other):
return isinstance(other, GraphDef) and self.nodedef == other.nodedef

def apply(
self, state: State, *states: State
) -> ApplyCaller[tuple['GraphDef[Node]', State]]:
self, state: GraphState, *states: GraphState
) -> ApplyCaller[tuple['GraphDef[Node]', GraphState]]:
accessor = DelayedAccessor()

def _apply(
accessor: DelayedAccessor, *args, **kwargs
) -> tuple[tp.Any, tuple[GraphDef[Node], State]]:
) -> tuple[tp.Any, tuple[GraphDef[Node], GraphState]]:
module = merge(self, state, *states)
fn = accessor(module)
out = fn(*args, **kwargs)
return out, flatten(module)[:2]

return CallableProxy(_apply, accessor) # type: ignore

def make_empty(self) -> Node:
return merge(self, State({}))


def _graphdef_flatten(graphdef: GraphDef[Node]):
# refmap is opaque, we don't propagate it
static = (graphdef.nodedef, graphdef.index_mapping)
Expand All @@ -392,7 +377,7 @@ def flatten(
/,
*,
idxmap: dict[Index, tp.Any] | None = None,
) -> tuple[GraphDef[Node], State, RefMap[tp.Any, Index]]:
) -> tuple[GraphDef[Node], GraphState, RefMap[tp.Any, Index]]:
refmap = RefMap[tp.Any, Index]()
flat_state: dict[PathParts, StateLeaf] = {}
nodedef = _graph_flatten((), refmap, flat_state, x)
Expand All @@ -402,7 +387,7 @@ def flatten(
else:
index_to_index = None
graphdef = GraphDef(nodedef, index_to_index)
return graphdef, State.from_flat_path(flat_state), refmap
return graphdef, GraphState.from_flat_path(flat_state), refmap


def _graph_flatten(
Expand Down Expand Up @@ -462,7 +447,7 @@ def _graph_flatten(

def unflatten(
graphdef: GraphDef[Node],
state: State,
state: GraphState,
/,
*,
idxmap: dict[Index, tp.Any] | None = None,
Expand Down Expand Up @@ -516,7 +501,7 @@ def _graph_unflatten(
node_impl = get_node_impl_for_type(nodedef.type)

def _get_children():
children: dict[Key, NodeLeaf | Node] = {}
children: dict[Key, StateLeaf | Node] = {}

# NOTE: we could allw adding new StateLeafs here
if unkown_keys := set(state) - set(nodedef.attributes):
Expand Down Expand Up @@ -650,20 +635,22 @@ def _get_children():
def graph_pop(
node: tp.Any,
filters: tuple[filterlib.Filter, ...],
) -> tuple[State, ...]:
) -> tuple[GraphState, ...]:
id_to_index: dict[int, Index] = {}
path_parts: PathParts = ()
predicates = tuple(filterlib.to_predicate(filter) for filter in filters)
flat_states: tuple[FlatState, ...] = tuple({} for _ in predicates)
flat_states: tuple[GraphFlatState, ...] = tuple({} for _ in predicates)
_graph_pop(node, id_to_index, path_parts, flat_states, predicates)
return tuple(State.from_flat_path(flat_state) for flat_state in flat_states)
return tuple(
GraphState.from_flat_path(flat_state) for flat_state in flat_states
)


def _graph_pop(
node: tp.Any,
id_to_index: dict[int, Index],
path_parts: PathParts,
flat_states: tuple[FlatState, ...],
flat_states: tuple[GraphFlatState, ...],
predicates: tuple[filterlib.Predicate, ...],
) -> None:
if not is_node(node):
Expand Down Expand Up @@ -888,12 +875,12 @@ def __eq__(self, other):
return isinstance(other, UpdateContext)

@tp.overload
def split(self, graph_node: A, /) -> tuple[GraphDef[A], State]: ...
def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ...

@tp.overload
def split(
self, graph_node: A, first: filterlib.Filter, /
) -> tuple[GraphDef[A], State]: ...
) -> tuple[GraphDef[A], GraphState]: ...

@tp.overload
def split(
Expand All @@ -903,11 +890,11 @@ def split(
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: ...
) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ...

def split(
self, node: A, *filters: filterlib.Filter
) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]:
) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
"""Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is
a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef
contains all the static information needed to reconstruct a ``Module`` graph, it is analogous
Expand Down Expand Up @@ -977,7 +964,7 @@ def split(
)
graphdef, state, refmap = flatten(node, idxmap=self.idxmap)

states: State | tuple[State, ...]
states: GraphState | tuple[GraphState, ...]
if len(filters) == 0:
states = (state,)
elif len(filters) == 1:
Expand All @@ -997,15 +984,15 @@ def split(
def merge(
self,
graphdef: GraphDef[A],
state: State,
*states: State,
state: GraphState,
*states: GraphState,
) -> A:
"""merge"""
if self.refmap is None:
raise ValueError('Cannot update a graphdef without refmap.')

if states:
state = State.merge(state, *states)
state = GraphState.merge(state, *states)

if graphdef.index_mapping is None:
node, self.idxmap = unflatten(graphdef, state)
Expand Down Expand Up @@ -1163,15 +1150,15 @@ def current_update_context(tag: str) -> UpdateContext:


@tp.overload
def split(graph_node: A, /) -> tuple[GraphDef[A], State]: ...
def split(graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ...


@tp.overload
def split(
graph_node: A,
first: filterlib.Filter,
/,
) -> tuple[GraphDef[A], State]: ...
) -> tuple[GraphDef[A], GraphState]: ...


@tp.overload
Expand All @@ -1181,12 +1168,12 @@ def split(
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: ...
) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ...


def split(
node: A, *filters: filterlib.Filter
) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]:
) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
"""Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is
a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef
contains all the static information needed to reconstruct a ``Module`` graph, it is analogous
Expand Down Expand Up @@ -1257,7 +1244,7 @@ def split(
"""
graphdef, state, _ = flatten(node)

states: State | tuple[State, ...]
states: GraphState | tuple[GraphState, ...]
if len(filters) == 0:
states = (state,)
elif len(filters) == 1:
Expand All @@ -1270,9 +1257,9 @@ def split(

def merge(
graphdef: GraphDef[A],
state: State,
state: GraphState,
/,
*states: State,
*states: GraphState,
) -> A:
"""The inverse of :func:`split`.
Expand Down Expand Up @@ -1308,25 +1295,25 @@ def merge(
states: Additional :class:`State` objects.
"""
if states:
state = State.merge(state, *states)
state = GraphState.merge(state, *states)

node, _ = unflatten(graphdef, state)
return node


def update(node, state: State, /, *states: State) -> None:
def update(node, state: GraphState, /, *states: GraphState) -> None:
if states:
state = State.merge(state, *states)
state = GraphState.merge(state, *states)

_graph_update_dynamic(node, state.raw_mapping)


@tp.overload
def state(node, /) -> State: ...
def state(node, /) -> GraphState: ...


@tp.overload
def state(node, first: filterlib.Filter, /) -> State: ...
def state(node, first: filterlib.Filter, /) -> GraphState: ...


@tp.overload
Expand All @@ -1336,16 +1323,16 @@ def state(
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[State, ...]: ...
) -> tuple[GraphState, ...]: ...


def state(
node,
*filters: filterlib.Filter,
) -> tp.Union[State, tuple[State, ...]]:
) -> tp.Union[GraphState, tuple[GraphState, ...]]:
state = flatten(node)[1]

states: State | tuple[State, ...]
states: GraphState | tuple[GraphState, ...]
if len(filters) == 0:
states = state
elif len(filters) == 1:
Expand All @@ -1366,7 +1353,7 @@ def pop(
node,
filter: filterlib.Filter,
/,
) -> State: ...
) -> GraphState: ...


@tp.overload
Expand All @@ -1376,25 +1363,29 @@ def pop(
filter2: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[State, ...]: ...
) -> tuple[GraphState, ...]: ...


def pop(node, *filters: filterlib.Filter) -> tp.Union[State, tuple[State, ...]]:
def pop(
node, *filters: filterlib.Filter
) -> tp.Union[GraphState, tuple[GraphState, ...]]:
if len(filters) == 0:
raise ValueError('Expected at least one filter')

id_to_index: dict[int, Index] = {}
path_parts: PathParts = ()
predicates = tuple(filterlib.to_predicate(filter) for filter in filters)
flat_states: tuple[FlatState, ...] = tuple({} for _ in predicates)
flat_states: tuple[GraphFlatState, ...] = tuple({} for _ in predicates)
_graph_pop(
node=node,
id_to_index=id_to_index,
path_parts=path_parts,
flat_states=flat_states,
predicates=predicates,
)
states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states)
states = tuple(
GraphState.from_flat_path(flat_state) for flat_state in flat_states
)

if len(states) == 1:
return states[0]
Expand Down
5 changes: 3 additions & 2 deletions flax/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@
from flax.nnx.nnx import variables as variableslib
from flax.nnx.nnx.graph import GraphDef
from flax.nnx.nnx.object import Object, ObjectMeta
from flax.nnx.nnx.state import State, StateLeaf
from flax.nnx.nnx.graph import GraphState, StateLeaf
from flax.nnx.nnx.state import State
from flax.typing import Path, PathParts

A = tp.TypeVar('A')
B = tp.TypeVar('B')
M = tp.TypeVar('M', bound='Module')
S = tp.TypeVar('S', bound=tp.Union[State, tuple[State, ...]])
S = tp.TypeVar('S', bound=tp.Union[GraphState, tuple[GraphState, ...]])
V = tp.TypeVar('V', bound=variableslib.Variable[tp.Any])
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])

Expand Down
Loading

0 comments on commit 89ced50

Please sign in to comment.