Merge pull request #20884 from superbobry:main
PiperOrigin-RevId: 636176070
jax authors committed May 22, 2024
2 parents b06663d + 602d4bd commit 47420a3
Showing 2 changed files with 39 additions and 7 deletions.
15 changes: 8 additions & 7 deletions jax/experimental/custom_partitioning.py
@@ -182,13 +182,14 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
           % (repr(closed_jaxpr.out_avals), repr(tiled_results))
       )
   axis_context = sharding_impls.SPMDAxisContext(mesh)
-  module = mlir.build_mlir_module_helper(
-      closed_jaxpr,
-      name="tmp_xla_computation",
-      platforms=module_context.platforms,
-      backend_or_name=module_context.backend_or_name,
-      axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)),
-  )
+  with core.extend_axis_env_nd(mesh.shape.items()):
+    module = mlir.build_mlir_module_helper(
+        closed_jaxpr,
+        name="tmp_xla_computation",
+        platforms=module_context.platforms,
+        backend_or_name=module_context.backend_or_name,
+        axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)),
+    )
   result_sharding = _pack_result_sharding(result_shape, result_shardings)
   return mlir.module_to_bytecode(module), arg_shardings, result_sharding

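Context for the change above: the custom_partitioning lowering path builds a temporary MLIR module for the partition body, and collectives in that body (e.g. a psum under a scan, as in https://github.com/google/jax/issues/20864) reference mesh axis names; the fix binds those names in JAX's axis environment while the module is built. Below is a minimal sketch of the mechanism, assuming a single axis 'x' of size 4 and using the internal jax._src.core module that the patched file itself uses (an illustration, not part of the commit):

import jax
import jax.numpy as jnp
from jax._src import core  # internal module, mirroring the patched file

def body(x):
  # A collective over the named axis 'x'.
  return jax.lax.psum(x, axis_name='x')

# Binding ('x', 4) here mirrors core.extend_axis_env_nd(mesh.shape.items())
# in the diff above; without the binding, JAX reports an unbound axis name.
with core.extend_axis_env_nd([('x', 4)]):
  jaxpr = jax.make_jaxpr(body)(jnp.ones(4))
print(jaxpr)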
31 changes: 31 additions & 0 deletions tests/pjit_test.py
@@ -1512,6 +1512,37 @@ def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
     pjit_f = pjit(jit_f, in_shardings=(P('x')), out_shardings=P('x'))
     self.assertArraysEqual(x, pjit_f(x))
 
+  @jtu.with_mesh([('x', 4)])
+  def test_custom_partitioner_with_scan(self):
+    self.skip_if_custom_partitioning_not_supported()
+
+    # This is a reproducer from https://github.com/google/jax/issues/20864.
+
+    @custom_partitioning
+    def f(x):
+      return jnp.sum(x)
+
+    def partition(mesh, arg_shapes, result_shape):
+      def lower_fn(xs):
+        def f(carry, x):
+          return carry + jax.lax.psum(jnp.sum(x), axis_name='x'), None
+
+        carry, _ = jax.lax.scan(f, 0, xs)
+        return carry
+
+      result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
+      arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
+      return mesh, lower_fn, result_shardings, arg_shardings
+
+    f.def_partition(
+        partition,
+        infer_sharding_from_operands=lambda mesh, *_: NamedSharding(mesh, P()),
+        propagate_user_sharding=lambda _, user_shape: user_shape.sharding)
+
+    pjit_f = pjit(f, in_shardings=P(None, 'x'))
+    xs = jnp.ones([32, 16])
+    self.assertEqual(pjit_f(xs), xs.sum())
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class AutoShardingPjitTest(jtu.JaxTestCase):
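To run the new test's scenario outside the JAX test harness, the mesh that @jtu.with_mesh([('x', 4)]) installs can be approximated as below. This is a sketch under assumptions not in the diff: it needs at least four devices (on CPU, set XLA_FLAGS=--xla_force_host_platform_device_count=4 before importing jax), and it substitutes a plain jnp.sum for the custom-partitioned f.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.pjit import pjit

# A 1-D device mesh named 'x', roughly what jtu.with_mesh([('x', 4)]) provides.
mesh = Mesh(np.array(jax.devices()[:4]), axis_names=('x',))
with mesh:
  # Inside the mesh context, pjit resolves P(None, 'x') against the mesh,
  # sharding the second dimension of the input four ways.
  f = pjit(lambda xs: jnp.sum(xs), in_shardings=P(None, 'x'))
  print(f(jnp.ones([32, 16])))  # 512.0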
