diff --git a/flax/nnx/tests/transforms_test.py b/flax/nnx/tests/transforms_test.py index 1d9c0b707b..93809deae4 100644 --- a/flax/nnx/tests/transforms_test.py +++ b/flax/nnx/tests/transforms_test.py @@ -15,6 +15,7 @@ import dataclasses import typing as tp from functools import partial +from absl.testing import parameterized import jax import jax.numpy as jnp @@ -345,7 +346,8 @@ def constrain_object(m): m.kernel.value.sharding -class TestGrad: + +class TestGrad(parameterized.TestCase): def test_grad(self): p1 = nnx.Param(10.0) p2 = nnx.Param(20.0) @@ -451,34 +453,21 @@ def test_multiple_inputs(self): assert 'bias' in grads assert grads.bias.value.shape == (3,) - def test_multiple_graph_nodes(self): - rngs = nnx.Rngs(0) - m1 = nnx.Linear(2, 3, rngs=rngs) - m2 = nnx.Linear(3, 3, rngs=rngs) - loss_fn = lambda m1, m2, x, y: jnp.mean((m2(m1(x)) - y) ** 2) - grad_fn = nnx.grad(loss_fn, argnums=(0, 1), wrt=nnx.Param) - x = jax.random.uniform(rngs(), (1, 2)) - y = jnp.ones((1, 3)) - grads_m1, grads_m2 = grad_fn(m1, m2, x, y) - - assert 'kernel' in grads_m1 - assert grads_m1.kernel.value.shape == (2, 3) - assert 'bias' in grads_m1 - assert grads_m1.bias.value.shape == (3,) - assert 'kernel' in grads_m2 - assert grads_m2.kernel.value.shape == (3, 3) - assert 'bias' in grads_m2 - assert grads_m2.bias.value.shape == (3,) - - def test_multiple_graph_nodes_mix_positions(self): + @parameterized.parameters( + {'loss_fn': lambda m1, m2, x, y: jnp.mean((m2(m1(x)) - y) ** 2), 'argnums': (0, 1)}, + {'loss_fn': lambda x, m1, y, m2: jnp.mean((m2(m1(x)) - y) ** 2), 'argnums': (1, 3)} + ) + def test_multiple_graph_nodes(self, loss_fn, argnums): rngs = nnx.Rngs(0) m1 = nnx.Linear(2, 3, rngs=rngs) m2 = nnx.Linear(3, 3, rngs=rngs) - loss_fn = lambda x, m1, y, m2: jnp.mean((m2(m1(x)) - y) ** 2) - grad_fn = nnx.grad(loss_fn, argnums=(1, 3), wrt=nnx.Param) + grad_fn = nnx.grad(loss_fn, argnums=argnums, wrt=nnx.Param) x = jax.random.uniform(rngs(), (1, 2)) y = jnp.ones((1, 3)) - grads_m1, grads_m2 = grad_fn(x, m1, y, m2) + inputs = [x, y] + inputs.insert(argnums[0], m1) + inputs.insert(argnums[1], m2) + grads_m1, grads_m2 = grad_fn(*inputs) assert 'kernel' in grads_m1 assert grads_m1.kernel.value.shape == (2, 3)