Skip to content

Commit

Permalink
[pallas] Split out pallas_vmap_test.py
Browse files Browse the repository at this point in the history
A couple of the VMAP tests are very slow, and they seem
good candidates for splitting out of pallas_test.py, which
is becoming very large anyway.

PiperOrigin-RevId: 649430760
  • Loading branch information
gnecula authored and jax authors committed Jul 4, 2024
1 parent 9214ace commit 5ad4790
Show file tree
Hide file tree
Showing 3 changed files with 269 additions and 176 deletions.
32 changes: 32 additions & 0 deletions tests/pallas/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,38 @@ jax_test(
] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
)

# Test target for pallas_vmap_test.py, which holds the vmap tests that were
# split out of pallas_test.py (a couple of them are very slow).
jax_test(
name = "pallas_vmap_test",
srcs = [
"pallas_vmap_test.py",
],
# Per-configuration tag overrides; here only gpu_a100_x32 is adjusted.
config_tags_overrides = {
"gpu_a100_x32": {
"ondemand": False, # Include in presubmit.
},
},
# NOTE(review): presumably the configurations this target is excluded from;
# confirm against the jax_test macro definition.
disable_configs = [
"cpu", # The 64-bit variant
"gpu",
"gpu_x32",
"gpu_a100",
"gpu_h100",
"gpu_p100",
"gpu_p100_x32",
],
# NOTE(review): presumably the extra configurations this target opts into
# (only the 32-bit A100/H100 GPU variants are listed); confirm against the
# jax_test macro definition.
enable_configs = [
"gpu_a100_x32",
"gpu_h100_x32",
],
# Pallas core plus the GPU and TPU backends/ops, and the absl testing and
# numpy Python dependencies.
deps = [
"//jax:pallas",
"//jax:pallas_gpu",
"//jax:pallas_gpu_ops",
"//jax:pallas_tpu",
"//jax:pallas_tpu_ops",
] + py_deps("absl/testing") + py_deps("numpy"),
)

jax_test(
name = "mosaic_gpu_test",
srcs = [
Expand Down
176 changes: 0 additions & 176 deletions tests/pallas/pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,182 +1116,6 @@ def setUp(self):
self.skipTest("On CPU the test works only in 32-bit mode")


class PallasCallVmapTest(PallasTest):
  """Checks that `jax.vmap` composes with kernels built by `self.pallas_call`.

  Every test wraps a small kernel with `self.pallas_call`, batches it with one
  or more `jax.vmap` layers, and compares the result against a plain
  `jnp`/`jax.vmap` reference computation.
  """

  def setUp(self):
    super().setUp()
    # NOTE(review): this skip is not conditioned on interpret mode, so
    # subclasses that inherit this setUp are skipped on TPU as well, even
    # though the message says interpret mode works. Confirm that is intended.
    if jtu.test_device_matches(["tpu"]):
      # TODO: most tests fail on TPU in non-interpreter mode
      self.skipTest("On TPU the test works only in interpret mode")

  def test_vmap_of_simple_kernel(self):
    # Scalar kernel, batched once over a length-8 vector.
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),
    )
    def add_one(x_ref, o_ref):
      o_ref[()] = x_ref[()] + 1

    out = jax.vmap(add_one)(jnp.arange(8))
    out_ref = jnp.arange(1, 9)
    np.testing.assert_allclose(out, out_ref)

  def test_vmap_of_simple_kernel_with_in_axes_None(self):
    # Only the first operand is batched; the second (the scalar 1) is passed
    # unbatched to every instance via in_axes=None.
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),
    )
    def add(x_ref, y_ref, o_ref):
      o_ref[()] = x_ref[()] + y_ref[()]

    out = jax.vmap(add, in_axes=(0, None))(jnp.arange(8), 1)
    out_ref = jnp.arange(1, 9)
    np.testing.assert_allclose(out, out_ref)

  def test_double_vmap_of_simple_kernel(self):
    # Two nested vmap layers over a (4, 2) input.
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),
    )
    def add_one(x_ref, o_ref):
      o_ref[()] = x_ref[()] + 1

    out = jax.vmap(jax.vmap(add_one))(jnp.arange(8).reshape((4, 2)))
    out_ref = jnp.arange(1, 9).reshape((4, 2))
    np.testing.assert_allclose(out, out_ref)

  def test_quadruple_vmap_of_simple_kernel(self):
    # Four nested vmap layers over a (5, 3, 4, 2) input (15 * 8 elements).
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),
    )
    def add_one(x_ref, o_ref):
      o_ref[()] = x_ref[()] + 1

    out = jax.vmap(jax.vmap(jax.vmap(jax.vmap(add_one))))(
        jnp.arange(15 * 8).reshape((5, 3, 4, 2)))
    out_ref = jnp.arange(1, 15 * 8 + 1).reshape((5, 3, 4, 2))
    np.testing.assert_allclose(out, out_ref)

  def test_quadruple_vmap_of_batched_kernel(self):
    # Same as above, but the kernel itself already has a grid of 7 programs,
    # each handling the element selected by its program id.
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct((7,), jnp.int32),
        grid=(7,))
    def add_one(x_ref, o_ref):
      i = pl.program_id(0)
      o_ref[i] = x_ref[i] + 1

    out = jax.vmap(jax.vmap(jax.vmap(jax.vmap(add_one))))(
        jnp.arange(15 * 8 * 7).reshape((5, 3, 4, 2, 7)))
    out_ref = jnp.arange(1, 15 * 8 * 7 + 1).reshape((5, 3, 4, 2, 7))
    np.testing.assert_allclose(out, out_ref)

  def test_vmap_of_slicing_kernel(self):
    # Kernel with a grid of 2 that indexes its refs by program id, batched
    # once over the leading axis of a (4, 2) input.
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
        grid=(2,))
    def add_one(x_ref, o_ref):
      i = pl.program_id(0)
      o_ref[i] = x_ref[i] + 1

    out = jax.vmap(add_one)(jnp.arange(8).reshape((4, 2)))
    out_ref = jnp.arange(1, 9).reshape((4, 2))
    np.testing.assert_allclose(out, out_ref)

  def test_vmap_of_kernel_with_input_output_aliases(self):
    # Input 1 is aliased to output 0, and the kernel reads o_ref before
    # writing it. The expected result arange(8) + 1 + 1 relies on o_ref
    # starting out with the value of the aliased (unbatched) input, 1.
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),
        input_output_aliases={1: 0},
        grid=())
    def add(x_ref, _, o_ref):
      o_ref[()] = x_ref[()] + o_ref[()] + 1

    out = jax.vmap(add, in_axes=(0, None))(jnp.arange(8), 1)
    out_ref = jnp.arange(2, 10)
    np.testing.assert_allclose(out, out_ref)

  def test_vmap_of_kernel_with_input_output_aliases_different_axes(self):
    # The aliased input is batched along axis 1 while the output is batched
    # along the default axis 0, so the reference is the transposed result.
    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct((4,), jnp.int32),
        input_output_aliases={0: 0},
        grid=(),
    )
    def add(x_ref, o_ref):
      o_ref[()] = x_ref[()] + 1

    out = jax.vmap(add, in_axes=1)(jnp.arange(8).reshape((4, 2)))
    out_ref = jnp.arange(1, 9).reshape((4, 2)).swapaxes(0, 1)
    np.testing.assert_allclose(out, out_ref)

  def test_vmap_of_slicing_kernel_different_axes(self):
    # Batches over axis 1 of a (2, 4) input and checks both out_axes=1 and
    # out_axes=0 against the same vmap applied to a pure-jnp reference.
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
        grid=(2,))
    def add_one(x_ref, o_ref):
      i = pl.program_id(0)
      o_ref[i] = x_ref[i] + 1

    add_one_ref = lambda x: x + 1
    x = jnp.arange(8).reshape((2, 4))

    out = jax.vmap(add_one, in_axes=1, out_axes=1)(x)
    out_ref = jax.vmap(add_one_ref, in_axes=1, out_axes=1)(x)
    np.testing.assert_allclose(out, out_ref)

    out = jax.vmap(add_one, in_axes=1, out_axes=0)(x)
    out_ref = jax.vmap(add_one_ref, in_axes=1, out_axes=0)(x)
    np.testing.assert_allclose(out, out_ref)

  def test_double_vmap_of_slicing_kernel_different_axes(self):
    # Outer vmap batches axis 0 and inner vmap batches axis 1 of an
    # (8, 4, 2) float input; compared with a tolerance because sin is a
    # floating-point computation.
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
        grid=(4,))
    def sin(x_ref, o_ref):
      i = pl.program_id(0)
      o_ref[i] = jnp.sin(x_ref[i])

    sin_ref = jnp.sin
    x = jnp.arange(64.).reshape((8, 4, 2))

    out = jax.vmap(jax.vmap(sin, in_axes=1), in_axes=0)(x)
    out_ref = jax.vmap(jax.vmap(sin_ref, in_axes=1), in_axes=0)(x)
    np.testing.assert_allclose(out, out_ref, atol=1e-3, rtol=1e-3)

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def test_small_large_vmap(self):
    # Catches https://github.com/google/jax/issues/18361
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
        grid=(2,))
    def add_one(x_ref, o_ref):
      o_ref[()] = x_ref[()] + 1

    add_one = jax.vmap(jax.vmap(add_one))
    add_one_ref = lambda x: x + 1

    # A small batch axis (4) followed by a large one (65536).
    x = random.randint(random.key(0), (4, 65536, 2), 0, 10000)

    out = add_one(x)
    out_ref = add_one_ref(x)

    np.testing.assert_allclose(out, out_ref)

  # Uses the same 65536-sized batch axis as test_small_large_vmap with one
  # more vmap layer, so it is gated behind the same slow-test flag.
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def test_small_small_large_vmap(self):
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
        grid=(2,))
    def add_one(x_ref, o_ref):
      o_ref[()] = x_ref[()] + 1

    add_one = jax.vmap(jax.vmap(jax.vmap(add_one)))
    add_one_ref = lambda x: x + 1

    # Two small batch axes (2, 2) followed by a large one (65536).
    x = random.randint(random.key(0), (2, 2, 65536, 2), 0, 10000)

    out = add_one(x)
    out_ref = add_one_ref(x)

    np.testing.assert_allclose(out, out_ref)


class PallasCallVmapInterpreterTest(PallasCallVmapTest):
  """Runs the inherited `PallasCallVmapTest` cases with `INTERPRET = True`."""

  INTERPRET = True

  def setUp(self):
    super().setUp()
    # TODO: assertion failures on CPU in 64-bit mode
    # Only the CPU + x64 combination is skipped; everything else proceeds.
    if not jtu.test_device_matches(["cpu"]):
      return
    if not jax.config.x64_enabled:
      return
    self.skipTest("On CPU the test works only in 32-bit mode")


class PallasOpsTest(PallasTest):

def setUp(self):
Expand Down
Loading

0 comments on commit 5ad4790

Please sign in to comment.