[Pallas] Add better reduction support.
Adds lowering rules for reduce_all, reduce_any, reduce_min, and reductions to scalars.

PiperOrigin-RevId: 648512839
justinjfu authored and jax authors committed Jul 2, 2024
1 parent d9bd358 commit e094acf
Showing 4 changed files with 175 additions and 91 deletions.
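
In kernel terms, the change means a Pallas TPU kernel can now reduce a block all the way down to a scalar. Below is a minimal sketch of what this enables; it is an illustrative example rather than code from this commit, and it assumes a TPU device and the Pallas API of this period (pl.BlockSpec and pltpu.SMEM as used in the tests further down):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def kernel(x_ref, o_ref):
  # Previously this raised NotImplementedError; the new lowering peels one
  # axis at a time and squeezes the result into a scalar register.
  o_ref[0, 0] = jnp.sum(x_ref[...])

x = jnp.arange(1024.0).reshape(8, 128)
out = pl.pallas_call(
    kernel,
    out_specs=pl.BlockSpec(block_shape=(1, 1), memory_space=pltpu.SMEM),
    out_shape=jax.ShapeDtypeStruct((1, 1), jnp.float32),
)(x)  # out[0, 0] == 523776.0, the sum of 0..1023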
jax/_src/pallas/mosaic/lowering.py (122 additions, 57 deletions)

@@ -1114,79 +1114,144 @@ def _multiple_of_lowering_rule(ctx: LoweringRuleContext, val, *, values):
lowering_rules[primitives.multiple_of_p] = _multiple_of_lowering_rule


-def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
-  (x_aval,) = ctx.avals_in
-  if not ctx.avals_out[0].shape:
-    raise NotImplementedError(
-        "Cannot lower reductions to scalar. Reduce to one element vector"
-        " instead, using keepdims=True."
-    )
-
-  out_type = aval_to_ir_type(ctx.avals_out[0])
-  if jnp.issubdtype(x_aval.dtype, jnp.floating):
-    kind = vector.CombiningKind.MAXIMUMF
-    val = ir.FloatAttr.get(ir.F32Type.get(), float("-inf"))
+def reduce_lowering_rule(reduce_fn, type_to_kind, type_to_identity):
+  def _lowering_rule(ctx: LoweringRuleContext, x, *, axes):
+    (x_aval,) = ctx.avals_in
+    if not ctx.avals_out[0].shape:
+      # If reducing to a scalar, we reduce sequentially by reducing along
+      # each dimension individually and squeezing at the end. This avoids the
+      # materialization of any scalar-shaped tensors by the reduction op
+      # itself which is not supported.
+      def _proxy_fun(val, *, axes):
+        # Mosaic compilation errors when attempting to reduce over all axes in
+        # order. However, it works when we reduce backwards from the last axis.
+        for ax in sorted(axes)[::-1]:
+          val = reduce_fn(val, axis=ax, keepdims=True)
+        # Squeeze lowers to vector.ExtractOp which will place the final
+        # value in a scalar register.
+        return jnp.squeeze(val)
+      proxy_lowering = lower_fun(
+          _proxy_fun, multiple_results=False)
+      return proxy_lowering(ctx, x, axes=axes)
+
+    if jnp.issubdtype(x_aval.dtype, jnp.floating):
+      kind = type_to_kind[jnp.floating]
+      val = type_to_identity[jnp.floating]
+      val = ir.FloatAttr.get(ir.F32Type.get(), val)
+    elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger):
+      raise NotImplementedError("Reductions over integers not implemented.")
+    elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger):
+      raise NotImplementedError("Reductions over integers not implemented.")
+    else:
+      raise NotImplementedError(
+          f"Reductions over {x_aval.dtype} not implemented.")
+    out_type = aval_to_ir_type(ctx.avals_out[0])
    identity = ir.DenseElementsAttr.get_splat(out_type, val)
-  elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger):
-    kind = ir.Attribute.parse("#vector.kind<maxsi>")
-    raise NotImplementedError
-  elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger):
-    kind = ir.Attribute.parse("#vector.kind<maxui>")
-    raise NotImplementedError
-  acc = arith.ConstantOp(out_type, identity)
-  op = vector.MultiDimReductionOp(
-      kind,
-      x,
-      acc,
-      ir.ArrayAttr.get(
-          [ir.IntegerAttr.get(ir.IntegerType.get_signless(64), a) for a in axes]
-      ),
-  )
-  return op.result
+    acc = arith.ConstantOp(out_type, identity)
+    op = vector.MultiDimReductionOp(
+        kind,
+        x,
+        acc,
+        ir.ArrayAttr.get(
+            [ir.IntegerAttr.get(ir.IntegerType.get_signless(64), a) for a in axes]
+        ),
+    )
+    return op.result
+  return _lowering_rule
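
The _proxy_fun above is the heart of the new scalar-reduction support: rather than asking the vector dialect for a scalar-shaped reduction, it peels one axis at a time, last axis first, with keepdims=True, and only squeezes at the very end. The same rewrite in plain jnp, as a self-contained sketch outside the lowering machinery:

import jax.numpy as jnp

def reduce_to_scalar(val, reduce_fn, axes):
  # Peel axes from last to first, keeping dims so that no intermediate
  # value is scalar-shaped.
  for ax in sorted(axes)[::-1]:
    val = reduce_fn(val, axis=ax, keepdims=True)
  # All reduced axes now have size 1; squeezing yields the scalar.
  return jnp.squeeze(val)

x = jnp.arange(12.0).reshape(3, 4)
assert reduce_to_scalar(x, jnp.max, (0, 1)) == jnp.max(x)  # both equal 11.0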


+REDUCE_MAX_KINDS = {
+    jnp.floating: vector.CombiningKind.MAXIMUMF,
+    jnp.signedinteger: vector.CombiningKind.MAXSI,
+    jnp.unsignedinteger: vector.CombiningKind.MAXUI,
+}
+REDUCE_MAX_IDENTITY = {
+    jnp.floating: float("-inf"),
+    jnp.signedinteger: np.iinfo(np.int32).min,
+}
+_reduce_max_lowering_rule = reduce_lowering_rule(
+    jnp.max, REDUCE_MAX_KINDS, REDUCE_MAX_IDENTITY)
lowering_rules[lax.reduce_max_p] = _reduce_max_lowering_rule


-def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
-  (x_aval,) = ctx.avals_in
-  if not ctx.avals_out[0].shape:
-    raise NotImplementedError(
-        "Cannot lower reductions to scalar. Reduce to one element vector"
-        " instead, using keepdims=True."
-    )
-
-  out_type = aval_to_ir_type(ctx.avals_out[0])
-  if jnp.issubdtype(x_aval.dtype, jnp.floating):
-    kind = ir.Attribute.parse("#vector.kind<add>")
-    val = ir.FloatAttr.get(ir.F32Type.get(), 0.0)
-    identity = ir.DenseElementsAttr.get_splat(out_type, val)
-  elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger):
-    kind = ir.Attribute.parse("#vector.kind<add>")
-    raise NotImplementedError
-  elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger):
-    kind = ir.Attribute.parse("#vector.kind<add>")
-    raise NotImplementedError
-  acc = arith.ConstantOp(out_type, identity)
-  op = vector.MultiDimReductionOp(
-      kind,
-      x,
-      acc,
-      ir.ArrayAttr.get(
-          [ir.IntegerAttr.get(ir.IntegerType.get_signless(64), a) for a in axes]
-      ),
-  )
-  return op.result
+REDUCE_MIN_KINDS = {
+    jnp.floating: vector.CombiningKind.MINIMUMF,
+    jnp.signedinteger: vector.CombiningKind.MINSI,
+    jnp.unsignedinteger: vector.CombiningKind.MINUI,
+}
+REDUCE_MIN_IDENTITY = {
+    jnp.floating: float("inf"),
+    jnp.signedinteger: np.iinfo(np.int32).max,
+}
+_reduce_min_lowering_rule = reduce_lowering_rule(
+    jnp.min, REDUCE_MIN_KINDS, REDUCE_MIN_IDENTITY)
+lowering_rules[lax.reduce_min_p] = _reduce_min_lowering_rule


+REDUCE_SUM_KINDS = {
+    jnp.floating: vector.CombiningKind.ADD,
+    jnp.signedinteger: vector.CombiningKind.ADD,
+    jnp.unsignedinteger: vector.CombiningKind.ADD,
+}
+REDUCE_SUM_IDENTITY = {
+    jnp.floating: 0.0,
+    jnp.signedinteger: 0,
+}
+_reduce_sum_lowering_rule = reduce_lowering_rule(
+    jnp.sum, REDUCE_SUM_KINDS, REDUCE_SUM_IDENTITY)
lowering_rules[lax.reduce_sum_p] = _reduce_sum_lowering_rule
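
Each KINDS table picks the vector-dialect combining kind per input dtype, and each IDENTITY table supplies the accumulator seed: the identity element that leaves the reduction unchanged (-inf for max, +inf for min, 0 for sum; the apply_vector_layout.cc change below installs the matching neutral for MINIMUMF on the Mosaic side). A quick numeric sanity check of those seeds, as a sketch:

import jax.numpy as jnp

x = jnp.array([3.0, -7.0, 2.0])
# Folding the identity element into the data must not change the result.
assert jnp.max(jnp.append(x, -jnp.inf)) == jnp.max(x)
assert jnp.min(jnp.append(x, jnp.inf)) == jnp.min(x)
assert jnp.sum(jnp.append(x, 0.0)) == jnp.sum(x)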


+def _reduce_and_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
+  def _proxy_reduce(arg, *, axes):
+    # Mosaic currently only supports float reductions, so we cast the boolean
+    # arg to a float and use reduce_min to implement reduce_and.
+    # TODO(justinfu): Implement this logic in Mosaic MultiDimReductionOp
+    # instead.
+    float_arg = jnp.where(arg, 1.0, 0.0)
+    return jnp.min(float_arg, axis=axes) > 0.0
+  proxy_lowering = lower_fun(
+      _proxy_reduce, multiple_results=False)
+  return proxy_lowering(ctx, x, axes=axes)
+
+lowering_rules[lax.reduce_and_p] = _reduce_and_lowering_rule


+def _reduce_or_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
+  def _proxy_reduce(arg, *, axes):
+    # Mosaic currently only supports float reductions, so we cast the boolean
+    # arg to a float and use reduce_max to implement reduce_or.
+    # TODO(justinfu): Implement this logic in Mosaic MultiDimReductionOp
+    # instead.
+    float_arg = jnp.where(arg, 1.0, 0.0)
+    return jnp.max(float_arg, axis=axes) > 0.0
+  proxy_lowering = lower_fun(
+      _proxy_reduce, multiple_results=False)
+  return proxy_lowering(ctx, x, axes=axes)
+
+lowering_rules[lax.reduce_or_p] = _reduce_or_lowering_rule
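
Both boolean rules piggyback on the float reductions above, since Mosaic currently reduces only floats: reduce_and becomes "the minimum of a 0/1 encoding stays above zero" and reduce_or becomes "the maximum rises above zero". The same rewrite as a standalone sketch:

import jax.numpy as jnp

def all_via_min(mask, axes):
  # all(mask) iff the smallest 0/1 value is still 1, i.e. > 0.
  return jnp.min(jnp.where(mask, 1.0, 0.0), axis=axes) > 0.0

def any_via_max(mask, axes):
  # any(mask) iff the largest 0/1 value reaches 1, i.e. > 0.
  return jnp.max(jnp.where(mask, 1.0, 0.0), axis=axes) > 0.0

m = jnp.array([[True, False], [True, True]])
assert bool(all_via_min(m, (0, 1))) == bool(jnp.all(m))
assert bool(any_via_max(m, (0, 1))) == bool(jnp.any(m))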


def _broadcast_in_dim_lowering_rule(
    ctx: LoweringRuleContext, val, *, shape, broadcast_dimensions
):
  (aval_in,) = ctx.avals_in
  (aval_out,) = ctx.avals_out

+  if jnp.issubdtype(aval_in.dtype, jnp.bool_):
+    # Direct broadcasts for bools are not supported in Mosaic due to booleans
+    # living in mask registers and broadcast operating on vregs. Broadcast as an
+    # integer instead and cast back to a bool.
+    # TODO(justinfu): Implement this logic in Mosaic BroadcastOp instead.
+    def _proxy_fun(val, *, shape, broadcast_dimensions):
+      int_val = jnp.where(val, 1, 0)
+      bcast_val = jax.lax.broadcast_in_dim(int_val, shape, broadcast_dimensions)
+      return bcast_val == 1
+    proxy_lowering = lower_fun(
+        _proxy_fun, multiple_results=False)
+    return proxy_lowering(
+        ctx, val, shape=shape, broadcast_dimensions=broadcast_dimensions)
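
The broadcast proxy just added follows the same pattern as the boolean reductions: booleans live in mask registers, so the rule broadcasts a 0/1 integer surrogate and compares back into a mask. Stripped of the lowering machinery, the rewrite is (a sketch for illustration):

import jax.numpy as jnp
from jax import lax

def broadcast_bool(val, shape, broadcast_dimensions):
  # Broadcast an integer 0/1 surrogate, then compare back into a bool mask.
  int_val = jnp.where(val, 1, 0)
  return lax.broadcast_in_dim(int_val, shape, broadcast_dimensions) == 1

v = jnp.array([True, False, True])
assert broadcast_bool(v, (2, 3), (1,)).shape == (2, 3)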

  if broadcast_dimensions:
    out_shape_list = [1] * len(shape)
    for i, s in zip(broadcast_dimensions, aval_in.shape):
jaxlib/mosaic/dialect/tpu/tpu.td (2 additions, 1 deletion)

@@ -158,7 +158,8 @@ class TPU_Op<string mnemonic, list<Trait> traits = []> :

def TPU_ReductionKind : I32EnumAttr<"ReductionKind", "Reduction kind", [
  I32EnumAttrCase<"SUM", 0, "sum">,
-  I32EnumAttrCase<"MAX", 1, "max">
+  I32EnumAttrCase<"MAX", 1, "max">,
+  I32EnumAttrCase<"MIN", 2, "min">
]> {
  let genSpecializedAttr = 0;
  let cppNamespace = "::mlir::tpu";
jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc (12 additions, 0 deletions)

@@ -3609,6 +3609,11 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
          builder.getF32Type(),
          APFloat::getInf(APFloat::IEEEsingle(), /*Negative=*/true));
    } break;
+    case vector::CombiningKind::MINIMUMF: {
+      neutral = builder.getFloatAttr(
+          builder.getF32Type(),
+          APFloat::getInf(APFloat::IEEEsingle(), /*Negative=*/false));
+    } break;
    default:
      return multi_reduction_op.emitOpError(
          "Not implemented: unsupported kind");
@@ -3701,6 +3706,9 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
    case vector::CombiningKind::MAXIMUMF:
      tpu_kind = tpu::ReductionKind::MAX;
      break;
+    case vector::CombiningKind::MINIMUMF:
+      tpu_kind = tpu::ReductionKind::MIN;
+      break;
    default:
      return multi_reduction_op.emitOpError(
          "Not implemented: unsupported reduction kind");
@@ -3757,6 +3765,10 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
            acc_vreg = builder.create<arith::MaximumFOp>(
                vreg.getLoc(), *acc_vreg, vreg);
            break;
+          case tpu::ReductionKind::MIN:
+            acc_vreg = builder.create<arith::MinimumFOp>(
+                vreg.getLoc(), *acc_vreg, vreg);
+            break;
        }
      }
      return absl::OkStatus();
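
The new MIN arm folds each vector register into the accumulator with an elementwise arith.MinimumFOp; the cross-lane reduction of the final accumulator happens separately. A rough Python model of that per-vreg walk (illustrative only; the vreg shape and surrounding tiling are simplified):

import jax.numpy as jnp

def accumulate_min(vregs):
  # Fold one register at a time into the accumulator, elementwise.
  acc = vregs[0]
  for vreg in vregs[1:]:
    acc = jnp.minimum(acc, vreg)
  return acc

chunks = [jnp.full((8, 128), v) for v in (3.0, 1.0, 2.0)]
assert float(accumulate_min(chunks)[0, 0]) == 1.0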
tests/pallas/tpu/pallas_call_test.py (39 additions, 33 deletions)

@@ -2059,47 +2059,53 @@ def outer_body(carry):

class PallasCallReductionTest(PallasTPUTest):

  def setUp(self):
    if jtu.device_under_test() != 'tpu':
      self.skipTest('Test only works on TPU')

    super().setUp()

-  def test_integer_sum(self):
-    def kernel(x_ref, o_ref):
-      x = x_ref[:]
-      # We'd prefer to say:
-      # o_ref[0, 0] = jnp.sum(x)
-      # But this currently hits issues in both Pallas and Mosaic lowering.
-      r = jnp.sum(x, keepdims=True, axis=1)
-      r = jnp.sum(r, keepdims=True, axis=0)
-      o_ref[0, 0] = r[0, 0]
+  @parameterized.named_parameters(
+      ('reduce_all_true', 'all_true', jnp.all, True),
+      ('reduce_all_false', 'all_false', jnp.all, False),
+      ('reduce_all_mixed', 'one_false', jnp.all, False),
+      ('reduce_any_true', 'all_true', jnp.any, True),
+      ('reduce_any_false', 'all_false', jnp.any, False),
+      ('reduce_any_mixed', 'one_false', jnp.any, True),
+  )
+  def test_reduce_boolean(self, input_type, reduction_op, expected_result):
+    def kernel(x_ref, ones_ref, o_ref):
+      # Convert float to bool with a comparison.
+      bool_x = x_ref[...] == ones_ref[...]
+      reduced_as_bool = reduction_op(bool_x, keepdims=True)
+      # Convert bool to float with a select.
+      float_value = jnp.where(reduced_as_bool, 1.0, 0.0)
+      o_ref[0, 0] = float_value[0, 0]
+
+    if input_type == 'all_true':
+      x = jnp.ones((8, 128), dtype=jnp.float32)
+    elif input_type == 'all_false':
+      x = jnp.zeros((8, 128), dtype=jnp.float32)
+    elif input_type == 'one_false':
+      x = jnp.ones((8, 128), dtype=jnp.float32)
+      x = x.at[0, 0].set(0.0)
+    ones = jnp.ones_like(x)

-    x = jnp.full([8, 128], 2.0)
    result = pl.pallas_call(
        kernel,
        in_specs=[
-            pl.BlockSpec((8, 128), lambda *_: (0, 0)),
+            pl.BlockSpec(lambda *_: (0, 0), (8, 128)),
+            pl.BlockSpec(lambda *_: (0, 0), (8, 128)),
        ],
-        out_specs=pl.BlockSpec((1, 1), memory_space=pltpu.SMEM),
+        out_specs=pl.BlockSpec(block_shape=(1, 1), memory_space=pltpu.SMEM),
        out_shape=jax.ShapeDtypeStruct([1, 1], jnp.float32),
        grid=(1,),
-    )(x)
-
-    np.testing.assert_array_equal(result[0, 0], 2048.0)
+    )(x, ones)
+    np.testing.assert_array_equal(result[0, 0], float(expected_result))

-  def test_integer_max(self):
+  @parameterized.named_parameters(
+      ('sum', jnp.sum,), ('max', jnp.max,), ('min', jnp.min,)
+  )
+  def test_reduce_float(self, reduction_op):
    def kernel(x_ref, o_ref):
-      x = x_ref[:]
-      # We'd prefer to say:
-      # o_ref[0, 0] = jnp.max(x)
-      # But this currently hits issues in both Pallas and Mosaic lowering.
-      x = jnp.max(x, keepdims=True, axis=1)
-      x = jnp.max(x, keepdims=True, axis=0)
-      o_ref[0, 0] = x[0, 0]
-
-    x = jnp.arange(1024.0)
-    x = jnp.reshape(x, [8, 128])
+      r = reduction_op(x_ref[...], keepdims=True)
+      o_ref[0, 0] = r[0, 0]
+
+    x = jax.random.normal(jax.random.key(0), (8, 128))
    result = pl.pallas_call(
        kernel,
        in_specs=[
@@ -2110,7 +2116,7 @@ def kernel(x_ref, o_ref):
        grid=(1,),
    )(x)

-    np.testing.assert_array_equal(result[0, 0], 1023.0)
+    np.testing.assert_allclose(result[0, 0], reduction_op(x), atol=1e-5)


class PallasCallDynamicDMATest(PallasTPUTest):
