Commit

parameterized nnx_transforms tests
chiamp committed May 7, 2024
1 parent 3d98eda commit e8ef6fa
Showing 1 changed file with 12 additions and 24 deletions.
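For context, absl's parameterized.parameters decorator runs the decorated test method once per argument set, with each dict's keys bound to the method's parameters; that is what lets the two hand-written test variants in the diff below collapse into one. A minimal standalone sketch of the mechanism (illustrative only, not code from this commit; the class and method names are hypothetical):

# Illustrative example, not part of the commit.
from absl.testing import parameterized

class DoublingTest(parameterized.TestCase):
  # The test runs twice, once per dict; keys must match the parameter names.
  @parameterized.parameters(
    {'x': 1, 'expected': 2},
    {'x': 2, 'expected': 4},
  )
  def test_double(self, x, expected):
    self.assertEqual(x * 2, expected)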
flax/experimental/nnx/tests/test_transforms.py (36 changes: 12 additions & 24 deletions)
@@ -14,6 +14,7 @@
 
 import typing as tp
 from functools import partial
+from absl.testing import parameterized
 
 import jax
 import jax.numpy as jnp
@@ -345,7 +346,7 @@ def constrain_object(m):
 
 
 
-class TestGrad:
+class TestGrad(parameterized.TestCase):
   def test_grad(self):
     p1 = nnx.Param(10.0)
     p2 = nnx.Param(20.0)
@@ -451,34 +452,21 @@ def test_multiple_inputs(self):
     assert 'bias' in grads
     assert grads.bias.value.shape == (3,)
 
-  def test_multiple_graph_nodes(self):
+  @parameterized.parameters(
+    {'loss_fn': lambda m1, m2, x, y: jnp.mean((m2(m1(x)) - y) ** 2), 'argnums': (0, 1)},
+    {'loss_fn': lambda x, m1, y, m2: jnp.mean((m2(m1(x)) - y) ** 2), 'argnums': (1, 3)}
+  )
+  def test_multiple_graph_nodes(self, loss_fn, argnums):
     rngs = nnx.Rngs(0)
     m1 = nnx.Linear(2, 3, rngs=rngs)
     m2 = nnx.Linear(3, 3, rngs=rngs)
-    loss_fn = lambda m1, m2, x, y: jnp.mean((m2(m1(x)) - y) ** 2)
-    grad_fn = nnx.grad(loss_fn, argnums=(0, 1), wrt=nnx.Param)
+    grad_fn = nnx.grad(loss_fn, argnums=argnums, wrt=nnx.Param)
     x = jax.random.uniform(rngs(), (1, 2))
     y = jnp.ones((1, 3))
-    grads_m1, grads_m2 = grad_fn(m1, m2, x, y)
-
-    assert 'kernel' in grads_m1
-    assert grads_m1.kernel.value.shape == (2, 3)
-    assert 'bias' in grads_m1
-    assert grads_m1.bias.value.shape == (3,)
-    assert 'kernel' in grads_m2
-    assert grads_m2.kernel.value.shape == (3, 3)
-    assert 'bias' in grads_m2
-    assert grads_m2.bias.value.shape == (3,)
-
-  def test_multiple_graph_nodes_mix_positions(self):
-    rngs = nnx.Rngs(0)
-    m1 = nnx.Linear(2, 3, rngs=rngs)
-    m2 = nnx.Linear(3, 3, rngs=rngs)
-    loss_fn = lambda x, m1, y, m2: jnp.mean((m2(m1(x)) - y) ** 2)
-    grad_fn = nnx.grad(loss_fn, argnums=(1, 3), wrt=nnx.Param)
-    x = jax.random.uniform(rngs(), (1, 2))
-    y = jnp.ones((1, 3))
-    grads_m1, grads_m2 = grad_fn(x, m1, y, m2)
+    inputs = [x, y]
+    inputs.insert(argnums[0], m1)
+    inputs.insert(argnums[1], m2)
+    grads_m1, grads_m2 = grad_fn(*inputs)
 
     assert 'kernel' in grads_m1
     assert grads_m1.kernel.value.shape == (2, 3)
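The inputs.insert calls in the merged test rebuild the positional argument order that each loss_fn variant expects, since argnums alone determines where the two modules sit. A quick worked illustration of that list surgery, using placeholder strings rather than the real modules and arrays (hypothetical values, not from the commit):

# Hypothetical stand-in values; the commit uses real modules and arrays.
# argnums == (0, 1): modules lead, matching loss_fn(m1, m2, x, y).
inputs = ['x', 'y']
inputs.insert(0, 'm1')  # ['m1', 'x', 'y']
inputs.insert(1, 'm2')  # ['m1', 'm2', 'x', 'y']

# argnums == (1, 3): modules interleaved, matching loss_fn(x, m1, y, m2).
inputs = ['x', 'y']
inputs.insert(1, 'm1')  # ['x', 'm1', 'y']
inputs.insert(3, 'm2')  # ['x', 'm1', 'y', 'm2']

Note that the trick relies on argnums being in ascending order: each insert lands its module at the final index because the later insertion never shifts elements before it.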
