Skip to content

Commit

Permalink
Make pl.num_programs lowering take the vmapped axes into account
Browse files Browse the repository at this point in the history
Otherwise the size of the wrong axis is returned.

PiperOrigin-RevId: 614677218
  • Loading branch information
apaszke authored and jax authors committed Mar 11, 2024
1 parent de455e7 commit 71ec6e3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 6 deletions.
28 changes: 22 additions & 6 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ class MeshContext:
@dataclasses.dataclass
class LoweringContext:
ir_context: ir.Context
grid_indices: Sequence[ir.Value] | None
grid_rank: int # Includes both user and vmap axes.
mapped_dims: tuple[int, ...] # Indices of vmapped grid dimensions.
user_grid_indices: Sequence[ir.Value] | None
block_shapes: list[tuple[int | pl_core.Mapped, ...]]
name_stack: source_info_util.NameStack
mesh_context: MeshContext | None
Expand Down Expand Up @@ -475,6 +477,8 @@ def body_func(*args):
mesh_context = None
lowering_context = LoweringContext(
ctx,
len(mosaic_grid_mapping.grid),
mosaic_grid_mapping.mapped_dims,
None,
arg_block_shapes,
source_info_util.NameStack(),
Expand Down Expand Up @@ -531,6 +535,8 @@ def body_func(*args):
mesh_context = None
lowering_context = LoweringContext(
ctx,
len(mosaic_grid_mapping.grid),
mosaic_grid_mapping.mapped_dims,
jaxpr_indices,
arg_block_shapes,
source_info_util.NameStack(),
Expand Down Expand Up @@ -1846,22 +1852,32 @@ def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs):


def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int):
  """Lowers `program_id` by looking up the user grid index for `axis`.

  Args:
    ctx: Rule context; only `ctx.lowering_context.user_grid_indices` is read.
    axis: Zero-based position within the user grid indices.

  Returns:
    The entry of `user_grid_indices` at position `axis`.

  Raises:
    ValueError: If `user_grid_indices` is None, or if `axis` does not satisfy
      `0 <= axis < len(user_grid_indices)`.
  """
  user_grid_indices = ctx.lowering_context.user_grid_indices
  if user_grid_indices is None:
    raise ValueError(
        f"program id: {axis} was passed, but user did not provide a grid."
    )
  length = len(user_grid_indices)
  # Negative axes are rejected rather than wrapped around.
  if not (0 <= axis < length):
    raise ValueError(
        f"user passed in program id with axis: {axis}, but grid only has"
        f" length: {length}"
    )
  return user_grid_indices[axis]
lowering_rules[primitives.program_id_p] = _program_id_lowering_rule

def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis: int):
  """Lowers `num_programs` for user grid axis `axis`.

  The grid seen here has `grid_rank` dimensions, some of which (those listed
  in `mapped_dims`) were introduced by vmap. `axis` counts only the remaining
  dimensions, so it is translated to a position in the full grid before the
  bound is queried.

  Args:
    ctx: Rule context; `ctx.lowering_context.grid_rank` and
      `ctx.lowering_context.mapped_dims` are read.
    axis: Zero-based index among the non-vmapped grid dimensions.

  Returns:
    The result of `tpu.iteration_bound` for the matching grid dimension.

  Raises:
    ValueError: If `axis` is negative or not smaller than the number of
      non-vmapped grid dimensions.
  """
  lowering_ctx = ctx.lowering_context
  mapped_axes = set(lowering_ctx.mapped_dims)
  # Positions in the full grid that are not vmapped, in order: entry k is the
  # full-grid dimension of the k-th user axis.
  user_axes = [
      i for i in range(lowering_ctx.grid_rank) if i not in mapped_axes
  ]
  # Reject negative axes explicitly so they can never resolve to a vmapped
  # dimension; this mirrors the range check in the program_id rule.
  if not (0 <= axis < len(user_axes)):
    raise ValueError(
        f"user passed in num programs with axis: {axis}, but grid only has"
        f" length: {len(user_axes)}"
    )
  return tpu.iteration_bound(user_axes[axis])
lowering_rules[primitives.num_programs_p] = _num_programs_lowering_rule


Expand Down
21 changes: 21 additions & 0 deletions tests/pallas/pallas_call_tpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,27 @@ def dynamic_kernel(steps):

self.assertEqual(dynamic_kernel(4), 8)

@parameterized.parameters(range(1, 4))
def test_vmap_num_programs(self, num_vmaps):
  """Checks pl.num_programs(0) under `num_vmaps` nested vmaps.

  The kernel fills its output block with pl.num_programs(0) on a grid of
  (8,); every element of the result is expected to equal 8 regardless of how
  many vmaps (each with axis_size=2) wrap the call.
  """
  result_ty = jax.ShapeDtypeStruct((8, 128), jnp.int32)

  def kernel(y_ref):
    y_ref[...] = jnp.full_like(y_ref, pl.num_programs(0))

  kernel_call = self.pallas_call(
      kernel,
      grid=(8,),
      out_specs=pl.BlockSpec(lambda i: (0, 0), result_ty.shape),
      out_shape=result_ty,
  )

  def add_vmap(inner):
    # Wrap a zero-argument callable in one more vmap of size 2.
    return lambda: jax.vmap(inner, axis_size=2)()

  wrapped = kernel_call
  for _ in range(num_vmaps):
    wrapped = add_vmap(wrapped)

  # Each vmap contributes one leading dimension of size 2.
  expected_shape = (2,) * num_vmaps + tuple(result_ty.shape)
  out = jax.jit(wrapped)()
  np.testing.assert_array_equal(out, np.full(expected_shape, 8.0))

def test_num_programs_block_spec(self):
def kernel(x_ref, y_ref):
y_ref[...] = x_ref[...]
Expand Down

0 comments on commit 71ec6e3

Please sign in to comment.