
Merge pull request #22049 from mattjj:shard-map-in-spec-none
PiperOrigin-RevId: 647076188
jax authors committed Jun 26, 2024
2 parents e0efcae + df90711 commit de474d1
Showing 2 changed files with 122 additions and 13 deletions.
40 changes: 27 additions & 13 deletions jax/experimental/shard_map.py
@@ -52,7 +52,7 @@
 from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3,
                            as_hashable_function, memoize, partition_list,
                            merge_lists, split_list, subs_list2)
-from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
+from jax.api_util import flatten_fun_nokwargs, shaped_abstractify, argnums_partial
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
@@ -103,7 +103,8 @@ def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
       be sharded along the named axes of ``mesh``. In each ``PartitionSpec``,
       mentioning a ``mesh`` axis name at a position expresses sharding the
       corresponding argument array axis along that positional axis; not
-      mentioning an axis name expresses replication.
+      mentioning an axis name expresses replication. If an argument, or argument
+      subtree, has a corresponding spec of None, that argument is not sharded.
     out_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as leaves,
       with a tree structure that is a tree prefix of the output of ``f``. Each
       ``PartitionSpec`` represents how the corresponding output shards should be
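A minimal usage sketch of the None-spec behavior this docstring change describes, assuming four available devices and a 1-D mesh; `aux` is a hypothetical non-array argument (the tests added below exercise the same behavior):

import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((4,)), ('i',))
x = jnp.arange(8.)
aux = {'scale': 2.0}  # arbitrary Python data; its spec below is None

# The None entry in in_specs means aux is passed to the body whole,
# while x is split into 4 shards along mesh axis 'i'.
f = shard_map(lambda aux, x: aux['scale'] * x, mesh,
              in_specs=(None, P('i')), out_specs=P('i'))
y = f(aux, x)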
@@ -153,13 +154,17 @@ def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs,
   def wrapped(*args):
     fun = lu.wrap_init(f)
     args_flat, in_tree = tree_flatten(args)
-    try: in_specs_flat = broadcast_prefix(in_specs, args)
+    fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
+    try: in_specs_flat = broadcast_prefix(in_specs, args,
+                                          is_leaf=lambda x: x is None)
     except ValueError:
       e, *_ = prefix_errors(in_specs, args)
       raise e('shard_map in_specs') from None
-    _check_specs_vs_args(f, mesh, in_tree, in_specs, in_specs_flat, args_flat)
+    dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat)
+                                        if s is not None)
+    fun, args_flat = argnums_partial(fun, dyn_argnums, args_flat)
+    _check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, args_flat)
     in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))
-    fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
 
     @memoize
     def out_names_thunk():
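A standalone sketch of the filtering step above, with made-up specs and stand-in arguments; zip(*...) plays the role of unzip2, and argnums_partial is only described in the comments:

from jax.sharding import PartitionSpec as P

in_specs_flat = [None, P('i'), None, P('i', 'j')]  # one spec per flat leaf
args_flat = ['cfg0', 'x0', 'cfg1', 'x1']           # stand-in flat arguments

# Keep only positions whose spec is not None: the "dynamic" argnums. They
# index into args_flat, which is why fun is flattened before this point.
dyn_argnums, dyn_specs = zip(*((i, s) for i, s in enumerate(in_specs_flat)
                               if s is not None))
assert dyn_argnums == (1, 3)
assert list(dyn_specs) == [P('i'), P('i', 'j')]
# argnums_partial(fun, dyn_argnums, args_flat) then closes the None-spec
# leaves ('cfg0', 'cfg1') over fun as constants, so only the dynamic,
# to-be-sharded arguments are traced and checked.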
@@ -258,21 +263,32 @@ class NoFail: pass
 
 def _check_specs_vs_args(
     f: Callable, mesh: Mesh, in_tree: PyTreeDef, in_specs: Specs,
-    in_specs_flat: list[P], xs: list) -> None:
+    dyn_argnums: Sequence[int], in_specs_flat: Sequence[P],
+    xs: Sequence) -> None:
   in_avals = map(shaped_abstractify, xs)
   fail = [a if not len(p) <= a.ndim else no_fail
           for p, a in zip(in_specs_flat, in_avals)]
   if any(f is not no_fail for f in fail):
+    fail = _expand_fail(in_tree, dyn_argnums, fail)
     msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail)
     raise ValueError(msg)
   in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))
   fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns)
                    for d, ns in names.items()) else no_fail
           for a, names in zip(in_avals, in_names_flat)]
   if any(f is not no_fail for f in fail):
+    fail = _expand_fail(in_tree, dyn_argnums, fail)
     msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail)
     raise ValueError(msg)
 
+def _expand_fail(in_tree: PyTreeDef, dyn_argnums: Sequence[int],
+                 fail: Sequence[core.ShapedArray | NoFail]
+                 ) -> list[core.ShapedArray | NoFail]:
+  fail_: list[core.ShapedArray | NoFail] = [no_fail] * in_tree.num_leaves
+  for i, f in zip(dyn_argnums, fail):
+    fail_[i] = f
+  return fail_
+
 def _spec_rank_error(
     error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs,
     fails: list[core.ShapedArray | NoFail]) -> str:
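A plain-Python sketch, with made-up values, of what the new _expand_fail does: failures computed over only the dynamic (non-None-spec) leaves are scattered back to full-tree leaf positions, so the error paths still point into the user's original argument pytree:

no_fail = object()                      # sentinel, as in the real code
num_leaves = 4                          # in_tree.num_leaves
dyn_argnums = (1, 3)                    # leaves whose specs were not None
fail = ['bad rank at leaf 1', no_fail]  # one entry per dynamic leaf

fail_ = [no_fail] * num_leaves
for i, f in zip(dyn_argnums, fail):
  fail_[i] = f
assert fail_ == [no_fail, 'bad rank at leaf 1', no_fail, no_fail]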
@@ -418,11 +434,11 @@ def _iter_paths(tree: PyTreeDef, specs: Specs, fails: list[T | NoFail]
   failures = tree_unflatten(tree, fails)
   failures_aug = generate_key_paths(failures)
   specs_ = tree_unflatten(tree_structure(specs), generate_key_paths(specs))
-  leaf = lambda x: type(x) is tuple and len(x) == 2 and type(x[1]) is P
+  leaf = lambda x: x is None or type(x) is tuple and len(x) == 2 and type(x[1]) is P
   specs_aug = broadcast_prefix(specs_, failures, is_leaf=leaf)
-  return [((spec_key, spec), (fail_key, fail_data))
-          for (spec_key, spec), (fail_key, fail_data)
-          in zip(specs_aug, failures_aug) if fail_data is not no_fail]
+  return [(s, (fail_key, fail_data)) for s, (fail_key, fail_data)
+          in zip(specs_aug, failures_aug)
+          if s is not None and fail_data is not no_fail]
 
 # Primitive
 
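A note on the is_leaf predicates used here and in wrapped above: pytree flattening treats None as an empty subtree by default, so a None spec would vanish before it could be paired with an argument or failure. A small illustration, using the public jax.tree.leaves as a stand-in for the internal broadcast_prefix:

import jax

# Default: None is an empty container and produces no leaves.
assert jax.tree.leaves((None, 1)) == [1]
# With is_leaf claiming None, it survives as a positional leaf.
assert jax.tree.leaves((None, 1), is_leaf=lambda x: x is None) == [None, 1]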
@@ -502,9 +518,7 @@ def _shard_map_staging(
   in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
   main = trace.main
   with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
-    jaxpr, genavals, consts, () = pe.trace_to_subjaxpr_dynamic(
-        f, main, in_avals_
-    )
+    jaxpr, genavals, consts, () = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
   out_avals_ = map(_check_shapedarray, genavals)
   _check_names(out_names_thunk(), out_avals_)
   in_rep = map(partial(_in_names_to_rep, mesh), in_names)
95 changes: 95 additions & 0 deletions tests/shard_map_test.py
@@ -1815,6 +1815,101 @@ def g(x):
     with self.assertRaisesRegex(ValueError, "spmd_axis_name cannot appear"):
       jax.vmap(g, spmd_axis_name='i')(xs)
 
+  def test_in_spec_none(self):
+    mesh = jtu.create_global_mesh((4, 2), ('i', 'j'))
+
+    x = jnp.arange(8).reshape(4, 2)
+
+    def f(o, x):
+      self.assertIs(o, obj)
+      return jnp.sin(x)
+
+    obj = object()
+    y = shard_map(f, mesh, (None, P('i')), P('i'))(obj, x)
+    self.assertAllClose(y, jnp.sin(x), check_dtypes=False)
+
+    obj = None
+    y = shard_map(f, mesh, (None, P('i')), P('i'))(None, x)
+    self.assertAllClose(y, jnp.sin(x), check_dtypes=False)
+
+    def f2(o, x):
+      self.assertIsInstance(o, dict)
+      self.assertIs(o['a'], obj['a'])
+      return jnp.sin(x)
+
+    obj = {'a': object()}
+    y = shard_map(f2, mesh, ({'a': None}, P('i')), P('i'))(obj, x)
+    self.assertAllClose(y, jnp.sin(x), check_dtypes=False)
+
+    def f3(x, o):
+      self.assertIs(o, obj)
+      return jnp.sin(x)
+
+    obj = object()
+    y = shard_map(f3, mesh, (P('i'), None), P('i'))(x, obj)
+    self.assertAllClose(y, jnp.sin(x), check_dtypes=False)
+
+    obj = None
+    y = shard_map(f3, mesh, (P('i'), None), P('i'))(x, obj)
+    self.assertAllClose(y, jnp.sin(x), check_dtypes=False)
+
+    def f4(o1, o2, x, o3):
+      self.assertIs(o1, obj1)
+      self.assertIs(o2[0], obj2[0])
+      self.assertIs(o2[1], obj2[1])
+      self.assertIs(o3, obj3)
+      return jnp.sin(x)
+
+    obj1 = object()
+    obj2 = (object(), object())
+    obj3 = object()
+    y = shard_map(f4, mesh, (None, None, P('i'), None), P('i'))(obj1, obj2, x, obj3)
+    self.assertAllClose(y, jnp.sin(x), check_dtypes=False)
+
+  def test_in_spec_none_divisibility_errors(self):
+    mesh = jtu.create_global_mesh((4, 2), ('i', 'j'))
+    x = jnp.arange(4).reshape(2, 2)
+
+    with self.assertRaisesRegex(ValueError, 'divisible'):
+      shard_map(lambda *_: None, mesh, (None, P('i')), None)(object(), x)
+
+    with self.assertRaisesRegex(ValueError, 'divisible'):
+      shard_map(lambda *_: None, mesh, (P('i'), None), None)(x, object())
+
+    with self.assertRaisesRegex(ValueError, 'divisible'):
+      shard_map(lambda *_: None, mesh, (P('i'), None), None
+                )(x, (object(), object()))
+
+    with self.assertRaisesRegex(ValueError, 'divisible'):
+      shard_map(lambda *_: None, mesh, (P('i'), (None, None)), None,
+                )(x, (object(), object()))
+
+    with self.assertRaisesRegex(ValueError, 'divisible'):
+      shard_map(lambda *_: None, mesh, ((None, None), P('i')), None,
+                )((object(), object()), x)
+
+  def test_in_spec_none_rank_errors(self):
+    mesh = jtu.create_global_mesh((4, 2), ('i', 'j'))
+    x = jnp.arange(4)
+
+    with self.assertRaisesRegex(ValueError, 'rank'):
+      shard_map(lambda *_: None, mesh, (None, P('i', 'j')), None)(object(), x)
+
+    with self.assertRaisesRegex(ValueError, 'rank'):
+      shard_map(lambda *_: None, mesh, (P('i', 'j'), None), None)(x, object())
+
+    with self.assertRaisesRegex(ValueError, 'rank'):
+      shard_map(lambda *_: None, mesh, (P('i', 'j'), None), None
+                )(x, (object(), object()))
+
+    with self.assertRaisesRegex(ValueError, 'rank'):
+      shard_map(lambda *_: None, mesh, (P('i', 'j'), (None, None)), None,
+                )(x, (object(), object()))
+
+    with self.assertRaisesRegex(ValueError, 'rank'):
+      shard_map(lambda *_: None, mesh, ((None, None), P('i', 'j')), None,
+                )((object(), object()), x)
+
+
 class FunSpec(NamedTuple):
   name: str
