Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] make State generic #3964

Merged
merged 1 commit into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .nnx.filterlib import All as All
from .nnx.filterlib import Not as Not
from .nnx.graph import GraphDef as GraphDef
from .nnx.graph import GraphState as GraphState
from .nnx.object import Object as Object
from .nnx.helpers import Dict as Dict
from .nnx.helpers import List as List
Expand Down
117 changes: 54 additions & 63 deletions flax/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,13 @@
import numpy as np
import typing_extensions as tpe

from flax.nnx.nnx import (
filterlib,
reprlib,
)
from flax.nnx.nnx import filterlib, reprlib
from flax.nnx.nnx.proxy_caller import (
ApplyCaller,
CallableProxy,
DelayedAccessor,
)
from flax.nnx.nnx.state import (
FlatState,
State,
StateLeaf,
is_state_leaf,
)
from flax.nnx.nnx.state import FlatState, State
from flax.nnx.nnx.variables import Variable, VariableState
from flax.typing import Key, PathParts

Expand All @@ -58,16 +50,13 @@
Leaf = tp.TypeVar('Leaf')
AuxData = tp.TypeVar('AuxData')

Updates = tp.Union[
A,
'GraphDef[A]',
tuple['GraphDef[A]', State],
tuple['GraphDef[A]', tuple[State, ...]],
State,
tuple[State, ...],
]
StateLeaf = tp.Union[VariableState[tp.Any], np.ndarray, jax.Array]
GraphState = State[Key, StateLeaf]
GraphFlatState = FlatState[StateLeaf]
Comment on lines +53 to +55
Copy link
Collaborator

@chiamp chiamp Jun 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For clarification: GraphState is a State, but StateLeaf, FlatState and GraphFlatState are not States?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. State[K, V] is now a generic mapping, GraphState is a type alias for the specific type of State returned by the graph APIs, and StateLeaf is a union describing the possible types of leaves you find in a GraphState.


NodeLeaf = tp.Union[Variable[tp.Any], np.ndarray, jax.Array]

def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
return isinstance(x, (VariableState, np.ndarray, jax.Array))


@dataclasses.dataclass
Expand All @@ -80,7 +69,7 @@ class GraphContext(threading.local):
GRAPH_CONTEXT = GraphContext()


def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:
def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
return isinstance(x, (Variable, np.ndarray, jax.Array))


Expand Down Expand Up @@ -349,24 +338,20 @@ def __eq__(self, other):
return isinstance(other, GraphDef) and self.nodedef == other.nodedef

def apply(
self, state: State, *states: State
) -> ApplyCaller[tuple['GraphDef[Node]', State]]:
self, state: GraphState, *states: GraphState
) -> ApplyCaller[tuple['GraphDef[Node]', GraphState]]:
accessor = DelayedAccessor()

def _apply(
accessor: DelayedAccessor, *args, **kwargs
) -> tuple[tp.Any, tuple[GraphDef[Node], State]]:
) -> tuple[tp.Any, tuple[GraphDef[Node], GraphState]]:
module = merge(self, state, *states)
fn = accessor(module)
out = fn(*args, **kwargs)
return out, flatten(module)[:2]

return CallableProxy(_apply, accessor) # type: ignore

def make_empty(self) -> Node:
return merge(self, State({}))


def _graphdef_flatten(graphdef: GraphDef[Node]):
# refmap is opaque, we don't propagate it
static = (graphdef.nodedef, graphdef.index_mapping)
Expand All @@ -392,7 +377,7 @@ def flatten(
/,
*,
idxmap: dict[Index, tp.Any] | None = None,
) -> tuple[GraphDef[Node], State, RefMap[tp.Any, Index]]:
) -> tuple[GraphDef[Node], GraphState, RefMap[tp.Any, Index]]:
refmap = RefMap[tp.Any, Index]()
flat_state: dict[PathParts, StateLeaf] = {}
nodedef = _graph_flatten((), refmap, flat_state, x)
Expand All @@ -402,7 +387,7 @@ def flatten(
else:
index_to_index = None
graphdef = GraphDef(nodedef, index_to_index)
return graphdef, State.from_flat_path(flat_state), refmap
return graphdef, GraphState.from_flat_path(flat_state), refmap


def _graph_flatten(
Expand Down Expand Up @@ -462,7 +447,7 @@ def _graph_flatten(

def unflatten(
graphdef: GraphDef[Node],
state: State,
state: GraphState,
/,
*,
idxmap: dict[Index, tp.Any] | None = None,
Expand Down Expand Up @@ -516,7 +501,7 @@ def _graph_unflatten(
node_impl = get_node_impl_for_type(nodedef.type)

def _get_children():
children: dict[Key, NodeLeaf | Node] = {}
children: dict[Key, StateLeaf | Node] = {}

# NOTE: we could allow adding new StateLeafs here
if unkown_keys := set(state) - set(nodedef.attributes):
Expand Down Expand Up @@ -650,20 +635,22 @@ def _get_children():
def graph_pop(
node: tp.Any,
filters: tuple[filterlib.Filter, ...],
) -> tuple[State, ...]:
) -> tuple[GraphState, ...]:
id_to_index: dict[int, Index] = {}
path_parts: PathParts = ()
predicates = tuple(filterlib.to_predicate(filter) for filter in filters)
flat_states: tuple[FlatState, ...] = tuple({} for _ in predicates)
flat_states: tuple[GraphFlatState, ...] = tuple({} for _ in predicates)
_graph_pop(node, id_to_index, path_parts, flat_states, predicates)
return tuple(State.from_flat_path(flat_state) for flat_state in flat_states)
return tuple(
GraphState.from_flat_path(flat_state) for flat_state in flat_states
)


def _graph_pop(
node: tp.Any,
id_to_index: dict[int, Index],
path_parts: PathParts,
flat_states: tuple[FlatState, ...],
flat_states: tuple[GraphFlatState, ...],
predicates: tuple[filterlib.Predicate, ...],
) -> None:
if not is_node(node):
Expand Down Expand Up @@ -888,12 +875,12 @@ def __eq__(self, other):
return isinstance(other, UpdateContext)

@tp.overload
def split(self, graph_node: A, /) -> tuple[GraphDef[A], State]: ...
def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ...

@tp.overload
def split(
self, graph_node: A, first: filterlib.Filter, /
) -> tuple[GraphDef[A], State]: ...
) -> tuple[GraphDef[A], GraphState]: ...

@tp.overload
def split(
Expand All @@ -903,11 +890,11 @@ def split(
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: ...
) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ...

def split(
self, node: A, *filters: filterlib.Filter
) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]:
) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
"""Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is
a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef
contains all the static information needed to reconstruct a ``Module`` graph, it is analogous
Expand Down Expand Up @@ -977,7 +964,7 @@ def split(
)
graphdef, state, refmap = flatten(node, idxmap=self.idxmap)

states: State | tuple[State, ...]
states: GraphState | tuple[GraphState, ...]
if len(filters) == 0:
states = (state,)
elif len(filters) == 1:
Expand All @@ -997,15 +984,15 @@ def split(
def merge(
self,
graphdef: GraphDef[A],
state: State,
*states: State,
state: GraphState,
*states: GraphState,
) -> A:
"""merge"""
if self.refmap is None:
raise ValueError('Cannot update a graphdef without refmap.')

if states:
state = State.merge(state, *states)
state = GraphState.merge(state, *states)

if graphdef.index_mapping is None:
node, self.idxmap = unflatten(graphdef, state)
Expand Down Expand Up @@ -1163,15 +1150,15 @@ def current_update_context(tag: str) -> UpdateContext:


@tp.overload
def split(graph_node: A, /) -> tuple[GraphDef[A], State]: ...
def split(graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ...


@tp.overload
def split(
graph_node: A,
first: filterlib.Filter,
/,
) -> tuple[GraphDef[A], State]: ...
) -> tuple[GraphDef[A], GraphState]: ...


@tp.overload
Expand All @@ -1181,12 +1168,12 @@ def split(
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: ...
) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ...


def split(
node: A, *filters: filterlib.Filter
) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]:
) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
"""Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is
a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef
contains all the static information needed to reconstruct a ``Module`` graph, it is analogous
Expand Down Expand Up @@ -1257,7 +1244,7 @@ def split(
"""
graphdef, state, _ = flatten(node)

states: State | tuple[State, ...]
states: GraphState | tuple[GraphState, ...]
if len(filters) == 0:
states = (state,)
elif len(filters) == 1:
Expand All @@ -1270,9 +1257,9 @@ def split(

def merge(
graphdef: GraphDef[A],
state: State,
state: GraphState,
/,
*states: State,
*states: GraphState,
) -> A:
"""The inverse of :func:`split`.

Expand Down Expand Up @@ -1308,25 +1295,25 @@ def merge(
states: Additional :class:`State` objects.
"""
if states:
state = State.merge(state, *states)
state = GraphState.merge(state, *states)

node, _ = unflatten(graphdef, state)
return node


def update(node, state: State, /, *states: State) -> None:
def update(node, state: GraphState, /, *states: GraphState) -> None:
if states:
state = State.merge(state, *states)
state = GraphState.merge(state, *states)

_graph_update_dynamic(node, state.raw_mapping)


@tp.overload
def state(node, /) -> State: ...
def state(node, /) -> GraphState: ...


@tp.overload
def state(node, first: filterlib.Filter, /) -> State: ...
def state(node, first: filterlib.Filter, /) -> GraphState: ...


@tp.overload
Expand All @@ -1336,16 +1323,16 @@ def state(
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[State, ...]: ...
) -> tuple[GraphState, ...]: ...


def state(
node,
*filters: filterlib.Filter,
) -> tp.Union[State, tuple[State, ...]]:
) -> tp.Union[GraphState, tuple[GraphState, ...]]:
state = flatten(node)[1]

states: State | tuple[State, ...]
states: GraphState | tuple[GraphState, ...]
if len(filters) == 0:
states = state
elif len(filters) == 1:
Expand All @@ -1366,7 +1353,7 @@ def pop(
node,
filter: filterlib.Filter,
/,
) -> State: ...
) -> GraphState: ...


@tp.overload
Expand All @@ -1376,25 +1363,29 @@ def pop(
filter2: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[State, ...]: ...
) -> tuple[GraphState, ...]: ...


def pop(node, *filters: filterlib.Filter) -> tp.Union[State, tuple[State, ...]]:
def pop(
node, *filters: filterlib.Filter
) -> tp.Union[GraphState, tuple[GraphState, ...]]:
if len(filters) == 0:
raise ValueError('Expected at least one filter')

id_to_index: dict[int, Index] = {}
path_parts: PathParts = ()
predicates = tuple(filterlib.to_predicate(filter) for filter in filters)
flat_states: tuple[FlatState, ...] = tuple({} for _ in predicates)
flat_states: tuple[GraphFlatState, ...] = tuple({} for _ in predicates)
_graph_pop(
node=node,
id_to_index=id_to_index,
path_parts=path_parts,
flat_states=flat_states,
predicates=predicates,
)
states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states)
states = tuple(
GraphState.from_flat_path(flat_state) for flat_state in flat_states
)

if len(states) == 1:
return states[0]
Expand Down
5 changes: 3 additions & 2 deletions flax/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@
from flax.nnx.nnx import variables as variableslib
from flax.nnx.nnx.graph import GraphDef
from flax.nnx.nnx.object import Object, ObjectMeta
from flax.nnx.nnx.state import State, StateLeaf
from flax.nnx.nnx.graph import GraphState, StateLeaf
from flax.nnx.nnx.state import State
from flax.typing import Path, PathParts

A = tp.TypeVar('A')
B = tp.TypeVar('B')
M = tp.TypeVar('M', bound='Module')
S = tp.TypeVar('S', bound=tp.Union[State, tuple[State, ...]])
S = tp.TypeVar('S', bound=tp.Union[GraphState, tuple[GraphState, ...]])
V = tp.TypeVar('V', bound=variableslib.Variable[tp.Any])
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])

Expand Down
Loading
Loading