Skip to content

Commit

Permalink
Merge pull request #4007 from google:nnx-fix-grad
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 645050031
  • Loading branch information
Flax Authors committed Jun 20, 2024
2 parents (f62da76 + 1e0a5d1), commit 54fe15f
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 20 deletions.
5 changes: 4 additions & 1 deletion flax/nnx/nnx/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ def to_shape_dtype(value):
return value.replace(
raw_value=jax.tree.map(to_shape_dtype, value.raw_value)
)
elif isinstance(value, (np.ndarray, jax.Array)):
elif (
isinstance(value, (np.ndarray, jax.Array))
and np.prod(value.shape) > 1
):
return Array(value.shape, value.dtype)
return value

Expand Down
16 changes: 8 additions & 8 deletions flax/nnx/nnx/training/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,19 +288,19 @@ class MultiMetric(Metric):
accuracy=Accuracy(
argname='values',
total=MetricState(
value=Array(shape=(), dtype=float32)
value=Array(0., dtype=float32)
),
count=MetricState(
value=Array(shape=(), dtype=int32)
value=Array(0, dtype=int32)
)
),
loss=Average(
argname='values',
total=MetricState(
value=Array(shape=(), dtype=float32)
value=Array(0., dtype=float32)
),
count=MetricState(
value=Array(shape=(), dtype=int32)
value=Array(0, dtype=int32)
)
)
)
Expand All @@ -309,21 +309,21 @@ class MultiMetric(Metric):
Accuracy(
argname='values',
total=MetricState(
value=Array(shape=(), dtype=float32)
value=Array(0., dtype=float32)
),
count=MetricState(
value=Array(shape=(), dtype=int32)
value=Array(0, dtype=int32)
)
)
>>> metrics.loss
Average(
argname='values',
total=MetricState(
value=Array(shape=(), dtype=float32)
value=Array(0., dtype=float32)
),
count=MetricState(
value=Array(shape=(), dtype=int32)
value=Array(0, dtype=int32)
)
)
Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/nnx/transforms/parallelization.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def vmap(
spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
# nnx specific
in_axes_kwargs: tp.Any = 0,
state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}),
state_axes: tp.Mapping[filterlib.Filter, int | None] = FrozenDict({...: 0}),
split_rngs: filterlib.Filter = ...,
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
) -> F:
Expand Down
22 changes: 12 additions & 10 deletions flax/nnx/nnx/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,23 +490,26 @@ def grad_fn(*args):
has_aux: bool
diff_args: list[int]
ctx = graph.current_update_context('grad')
*_args, f, graphdef, non_diff_state, has_aux, diff_args = args
*args, f, graphdef, non_diff_state, has_aux, diff_args = args

# rebuild diff_state from substates in args
diff_state = State({})
for i in diff_args:
diff_state[i] = _args[i]
diff_state = State({0: diff_state.raw_mapping})
diff_state[i] = args[i]
diff_state: graph.GraphState = State({0: diff_state.raw_mapping})

diff_graph_nodes, input_nodes = ctx.merge(
graphdef, diff_state, non_diff_state
)

# add nodes to the args
for i, arg in diff_graph_nodes.items():
_args[i] = arg
args[i] = arg

out = f(*_args)
# add other nodes to the args
args = graph.insert_graph_nodes(args, input_nodes)

out = f(*args)

out, out_nodes = graph.extract_graph_nodes(out)

Expand Down Expand Up @@ -535,14 +538,13 @@ def _grad_general(
def grad_wrapper(*args):
ctx: graph.UpdateContext = graph.current_update_context('grad')
_argnums = _normalize_sequence(argnums)
_, input_nodes = graph.extract_graph_nodes(args)

_args = list(args)
diff_graph_nodes: dict[int, tp.Any] = {
i: arg
for i, arg in enumerate(args)
if i in _argnums and graph.is_node(arg)
}
args, input_nodes = graph.extract_graph_nodes(args)
args = list(args)

def only_diff(path: tuple, value: tp.Any) -> bool:
# diff_graph_nodes is the first element in the tuple
Expand All @@ -557,7 +559,7 @@ def only_diff(path: tuple, value: tp.Any) -> bool:
if 0 in diff_state:
for i, diff_substate in diff_state[0].items(): # type: ignore
assert isinstance(i, int)
_args[i] = diff_substate
args[i] = diff_substate
diff_args.append(i)
transform = jax.value_and_grad if return_value else jax.grad

Expand All @@ -570,7 +572,7 @@ def only_diff(path: tuple, value: tp.Any) -> bool:
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
)(*_args, f, graphdef, non_diff_state, has_aux, diff_args)
)(*args, f, graphdef, non_diff_state, has_aux, diff_args)

if return_value:
if has_aux:
Expand Down

0 comments on commit 54fe15f

Please sign in to comment.