Read the layout set by with_sharding_constraint and set the top module level `out_layout` to `AUTO` if wsc layout is not None.

This will allow XLA to override the entry_computation_layout with the layout set via a custom call (i.e. via wsc).

PiperOrigin-RevId: 648911765
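
A minimal sketch of the user-visible effect, modeled on the updated test below. The 2x2 device grid (at least four devices), the array shape, the TPU-style tiling, and the `jax.experimental.layout` import path are assumptions of the sketch, not part of the commit:

import math
import numpy as np
import jax
from jax.sharding import NamedSharding, PartitionSpec as P
from jax.experimental.layout import Layout, DeviceLocalLayout as DLL

mesh = jax.sharding.Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x'))
np_inp = np.arange(math.prod((16, 128))).reshape(16, 128)
arr = jax.device_put(np_inp, s)

# A concrete device-local layout to constrain to inside the jitted function.
custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128),))

@jax.jit  # no out_shardings=Layout(DLL.AUTO) needed after this change
def f(x):
  y = x.T
  # The layout attached here is read during lowering; the module-level
  # out_layout becomes AUTO so XLA honors the layout set via the custom call.
  return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s))

out = f(arr)
assert out.layout == arr.layout  # mirrors the assertion in the updated test

Before this change, the test had to request `out_shardings=Layout(DLL.AUTO)` explicitly so that the compiler could honor the constrained layout.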
yashk2810 authored and jax authors committed Jul 3, 2024
1 parent f089ecc commit 8844877
Showing 3 changed files with 45 additions and 15 deletions.
37 changes: 37 additions & 0 deletions jax/_src/interpreters/pxla.py
@@ -2108,6 +2108,36 @@ def _default_rule(prim, num_outvars, *_, **__):
  return tuple(safe_map(read, jaxpr.outvars))


@weakref_lru_cache
def get_out_layouts_via_propagation(closed_jaxpr: core.ClosedJaxpr
                                    ) -> tuple[None | DeviceLocalLayout, ...]:
  from jax._src import pjit

  env = {}  # type: ignore
  jaxpr = closed_jaxpr.jaxpr

  def read(var):
    if type(var) is core.Literal:
      return None
    return env[var]

  def write(var, val):
    env[var] = val

  safe_map(write, jaxpr.invars, [None] * len(jaxpr.invars))
  safe_map(write, jaxpr.constvars, [None] * len(jaxpr.constvars))

  for eqn in jaxpr.eqns:
    # TODO(yashkatariya): Replace this with a registration system when there are
    # more primitives for layout propagation.
    if eqn.primitive is pjit.sharding_constraint_p:
      out_eqn_layouts = [eqn.params['layout']]
    else:
      out_eqn_layouts = [None] * len(eqn.outvars)
    safe_map(write, eqn.outvars, out_eqn_layouts)
  return tuple(safe_map(read, jaxpr.outvars))
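
For illustration only, a hedged sketch of what this propagation returns for a traced jaxpr. The single-device mesh, the traced function `g`, the tiling, and calling this internal helper directly are assumptions of the sketch, not part of the commit:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.layout import Layout, DeviceLocalLayout as DLL
from jax._src.interpreters import pxla

mesh = Mesh(np.array(jax.devices()[:1]), ('x',))
s = NamedSharding(mesh, P('x'))
dll = DLL(major_to_minor=(0, 1), tiling=((8, 128),))

def g(x):
  a = x * 2                                                    # no constraint
  b = jax.lax.with_sharding_constraint(x + 1, Layout(dll, s))  # wsc with a layout
  return a, b

closed_jaxpr = jax.make_jaxpr(g)(jnp.zeros((16, 128)))
# Per the rule above: None for `a`, the constrained layout for `b`.
print(pxla.get_out_layouts_via_propagation(closed_jaxpr))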


MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]


@@ -2199,6 +2229,13 @@ def lower_sharding_computation(
  global_in_avals = closed_jaxpr.in_avals
  global_out_avals = closed_jaxpr.out_avals

  # If a layout is propagated, set the out_layout in the top module to AUTO
  # so that XLA can override the entry_computation_layout. The propagated
  # layout will be set via a custom call.
  out_layouts_via_prop = get_out_layouts_via_propagation(closed_jaxpr)
  out_layouts = tuple(DeviceLocalLayout.AUTO if p is not None else o
                      for o, p in safe_zip(out_layouts, out_layouts_via_prop))

  assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
      len(out_shardings), len(out_layouts), len(global_out_avals))
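
As a small illustration of the override (stand-in values rather than real layouts, and an assumed public import path): only positions where propagation found a layout are switched to AUTO; user-provided entries are left untouched:

from jax.experimental.layout import DeviceLocalLayout  # assumed import path

# Stand-in values: strings instead of real device-local layouts.
out_layouts = (None, 'user_requested_dll')
out_layouts_via_prop = ('layout_from_wsc', None)   # propagation hit output 0 only
out_layouts = tuple(DeviceLocalLayout.AUTO if p is not None else o
                    for o, p in zip(out_layouts, out_layouts_via_prop))
assert out_layouts == (DeviceLocalLayout.AUTO, 'user_requested_dll')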

8 changes: 3 additions & 5 deletions jax/_src/pjit.py
@@ -369,10 +369,7 @@ def _split_layout_and_sharding(entries):
  layouts, shardings = [], []

  for e in entries_flat:
    if e is None or is_unspecified_or_auto(e):
      layouts.append(None)
      shardings.append(e)
    elif isinstance(e, Layout):
    if isinstance(e, Layout):
      layouts.append(e.device_local_layout)
      shardings.append(e.sharding)
    elif isinstance(e, (DeviceLocalLayout, AutoLayout)):
@@ -1430,7 +1427,8 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
  for arg, jit_in_l, rs, aval in safe_zip(
      args, jit_in_layouts, resolved_in_shardings, in_avals):
    arg_layout, committed = (
        pxla._maybe_get_default_layout(getattr(arg, 'layout', None), jit_in_l, rs, aval),
        pxla._maybe_get_default_layout(getattr(arg, 'layout', None), jit_in_l,
                                       rs, aval),
        getattr(arg, '_committed', True))
    # Sharding can be unspecified when array is committed if it's a PmapSharding.
    is_pmap_sharding = (is_unspecified(rs) or
15 changes: 5 additions & 10 deletions tests/layout_test.py
@@ -16,7 +16,6 @@
import math
from absl.testing import absltest
import numpy as np
from functools import partial

import jax
import jax.numpy as jnp
@@ -359,19 +358,15 @@ def test_make_array_from_callback(self):

  def test_wsc_concrete_layout(self):
    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
    shape = (128, 128)
    shape = (16, 128)
    s = NamedSharding(mesh, P('x'))
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    arr = jax.device_put(np_inp, s)

    # Create a custom layout instead of using `arr.layout` to test the API.
    custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128),))

    # We need AUTO so that XLA can override the entry computation layout set.
    # TODO(yashkatariya): Expose a config that sets out_shardings to AUTO by
    # default instead of `None` i.e. default layout and let the compiler choose
    # the layout or try setting it to AUTO by default and see if there is chaos.
    @partial(jax.jit, out_shardings=Layout(DLL.AUTO))
    @jax.jit
    def f(x):
      y = x.T
      # Constrain `y` to the original layout of `arr` because without it,
@@ -383,17 +378,17 @@ def f(x):
    self.assertEqual(out.layout, arr.layout)
    self.assertArraysEqual(out, np_inp.T)

  def test_wsc_concrete_layout_bfloat16(self):
  def test_wsc_bfloat16_concrete_layout(self):
    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
    shape = (128, 128)
    shape = (16, 128)
    s = NamedSharding(mesh, P('x'))
    inp = jnp.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape)
    arr = jax.device_put(inp, s)

    # Create a custom layout instead of using `arr.layout` to test the API.
    custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128), (2, 1)))

    @partial(jax.jit, out_shardings=Layout(DLL.AUTO))
    @jax.jit
    def f(x):
      y = x.T
      # Constrain `y` to the original layout of `arr` because without it,
