parameterized NNX transforms tests #3906

Merged · 1 commit · Jun 3, 2024
37 changes: 13 additions & 24 deletions flax/nnx/tests/transforms_test.py
@@ -15,6 +15,7 @@
 import dataclasses
 import typing as tp
 from functools import partial
+from absl.testing import parameterized

 import jax
 import jax.numpy as jnp
@@ -345,7 +346,8 @@ def constrain_object(m):
     m.kernel.value.sharding


-class TestGrad:
+
+class TestGrad(parameterized.TestCase):
   def test_grad(self):
     p1 = nnx.Param(10.0)
     p2 = nnx.Param(20.0)
@@ -451,34 +453,21 @@ def test_multiple_inputs(self):
     assert 'bias' in grads
     assert grads.bias.value.shape == (3,)

-  def test_multiple_graph_nodes(self):
-    rngs = nnx.Rngs(0)
-    m1 = nnx.Linear(2, 3, rngs=rngs)
-    m2 = nnx.Linear(3, 3, rngs=rngs)
-    loss_fn = lambda m1, m2, x, y: jnp.mean((m2(m1(x)) - y) ** 2)
-    grad_fn = nnx.grad(loss_fn, argnums=(0, 1), wrt=nnx.Param)
-    x = jax.random.uniform(rngs(), (1, 2))
-    y = jnp.ones((1, 3))
-    grads_m1, grads_m2 = grad_fn(m1, m2, x, y)
-
-    assert 'kernel' in grads_m1
-    assert grads_m1.kernel.value.shape == (2, 3)
-    assert 'bias' in grads_m1
-    assert grads_m1.bias.value.shape == (3,)
-    assert 'kernel' in grads_m2
-    assert grads_m2.kernel.value.shape == (3, 3)
-    assert 'bias' in grads_m2
-    assert grads_m2.bias.value.shape == (3,)
-
-  def test_multiple_graph_nodes_mix_positions(self):
+  @parameterized.parameters(
+    {'loss_fn': lambda m1, m2, x, y: jnp.mean((m2(m1(x)) - y) ** 2), 'argnums': (0, 1)},
+    {'loss_fn': lambda x, m1, y, m2: jnp.mean((m2(m1(x)) - y) ** 2), 'argnums': (1, 3)}
+  )
+  def test_multiple_graph_nodes(self, loss_fn, argnums):
     rngs = nnx.Rngs(0)
     m1 = nnx.Linear(2, 3, rngs=rngs)
     m2 = nnx.Linear(3, 3, rngs=rngs)
-    loss_fn = lambda x, m1, y, m2: jnp.mean((m2(m1(x)) - y) ** 2)
-    grad_fn = nnx.grad(loss_fn, argnums=(1, 3), wrt=nnx.Param)
+    grad_fn = nnx.grad(loss_fn, argnums=argnums, wrt=nnx.Param)
     x = jax.random.uniform(rngs(), (1, 2))
     y = jnp.ones((1, 3))
-    grads_m1, grads_m2 = grad_fn(x, m1, y, m2)
+    inputs = [x, y]
+    inputs.insert(argnums[0], m1)
+    inputs.insert(argnums[1], m2)
+    grads_m1, grads_m2 = grad_fn(*inputs)

     assert 'kernel' in grads_m1
     assert grads_m1.kernel.value.shape == (2, 3)
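
As an aside, a minimal standalone sketch (not part of this PR) of how the inputs.insert trick rebuilds the positional argument order for each argnums case, and how absl.testing.parameterized expands one method into multiple test cases. The class name, method name, and string placeholders below are hypothetical, used only for illustration:

# Hypothetical sketch: @parameterized.parameters runs the decorated method
# once per parameter dict, passing the dict entries as keyword arguments.
from absl.testing import absltest, parameterized


class InsertOrderingDemo(parameterized.TestCase):

  @parameterized.parameters(
      {'argnums': (0, 1)},
      {'argnums': (1, 3)},
  )
  def test_argument_ordering(self, argnums):
    # Start from the non-module inputs, then insert the module placeholders
    # at the positions named by argnums, mirroring the parameterized test.
    x, y, m1, m2 = 'x', 'y', 'm1', 'm2'
    inputs = [x, y]
    inputs.insert(argnums[0], m1)
    inputs.insert(argnums[1], m2)
    if argnums == (0, 1):
      self.assertEqual(inputs, ['m1', 'm2', 'x', 'y'])  # grad_fn(m1, m2, x, y)
    else:
      self.assertEqual(inputs, ['x', 'm1', 'y', 'm2'])  # grad_fn(x, m1, y, m2)


if __name__ == '__main__':
  absltest.main()

For the argnums pairs used here, the sequential inserts land each module at its intended final position because the indices are ascending: inserting m1 first means m2's later index is already correct in the final list. When run, the parameterized method appears as two separate cases, suffixed by parameter index (e.g. test_argument_ordering0 and test_argument_ordering1).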