diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 80411f12a838..1f6721952112 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1114,79 +1114,146 @@ def _multiple_of_lowering_rule(ctx: LoweringRuleContext, val, *, values): lowering_rules[primitives.multiple_of_p] = _multiple_of_lowering_rule -def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): - (x_aval,) = ctx.avals_in - if not ctx.avals_out[0].shape: - raise NotImplementedError( - "Cannot lower reductions to scalar. Reduce to one element vector" - " instead, using keepdims=True." - ) - - out_type = aval_to_ir_type(ctx.avals_out[0]) - if jnp.issubdtype(x_aval.dtype, jnp.floating): - kind = vector.CombiningKind.MAXIMUMF - val = ir.FloatAttr.get(ir.F32Type.get(), float("-inf")) +def reduce_lowering_rule(reduce_fn, type_to_kind, type_to_identity): + def _lowering_rule(ctx: LoweringRuleContext, x, *, axes): + (x_aval,) = ctx.avals_in + if not ctx.avals_out[0].shape: + # If reducing to a scalar, we reduce recursively by reducing along + # each dimension individually and squeezing. This avoids the + # materialization of any scalar-shaped tensors which would need + # to be placed into scalar registers explicitly. + def _proxy_fun(val, *, axes): + # Mosaic compilation errors when attempting to reduce over all axes in + # order. However, it works when we reduce backwards from the last axis. + for ax in sorted(axes)[::-1]: + val = reduce_fn(val, axis=ax, keepdims=True) + # Squeeze lowers to vector.ExtractOp which will place the final + # value in a scalar register. + return jnp.squeeze(val) + proxy_lowering = lower_fun( + _proxy_fun, multiple_results=False) + return proxy_lowering(ctx, x, axes=axes) + + if jnp.issubdtype(x_aval.dtype, jnp.floating): + kind = type_to_kind[jnp.floating] + val = type_to_identity[jnp.floating] + val = ir.FloatAttr.get(ir.F32Type.get(), val) + elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger): + raise NotImplementedError("Reductions over integers not implemented.") + elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger): + raise NotImplementedError("Reductions over integers not implemented.") + else: + raise NotImplementedError( + f"Reduction over {x_aval.dtype} not implemented.") + out_type = aval_to_ir_type(ctx.avals_out[0]) identity = ir.DenseElementsAttr.get_splat(out_type, val) - elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger): - kind = ir.Attribute.parse("#vector.kind") - raise NotImplementedError - elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger): - kind = ir.Attribute.parse("#vector.kind") - raise NotImplementedError - acc = arith.ConstantOp(out_type, identity) - op = vector.MultiDimReductionOp( - kind, - x, - acc, - ir.ArrayAttr.get( - [ir.IntegerAttr.get(ir.IntegerType.get_signless(64), a) for a in axes] - ), - ) - return op.result + acc = arith.ConstantOp(out_type, identity) + op = vector.MultiDimReductionOp( + kind, + x, + acc, + ir.ArrayAttr.get( + [ir.IntegerAttr.get(ir.IntegerType.get_signless(64), a) for a in axes] + ), + ) + return op.result + return _lowering_rule +REDUCE_MAX_KINDS = { + jnp.floating: vector.CombiningKind.MAXIMUMF, + jnp.signedinteger: vector.CombiningKind.MAXSI, + jnp.unsignedinteger: vector.CombiningKind.MAXUI, +} +REDUCE_MAX_IDENTITY = { + jnp.floating: float("-inf"), + jnp.signedinteger: np.iinfo(np.int32).min, +} +_reduce_max_lowering_rule = reduce_lowering_rule( + jnp.max, REDUCE_MAX_KINDS, REDUCE_MAX_IDENTITY) lowering_rules[lax.reduce_max_p] = _reduce_max_lowering_rule -def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): - (x_aval,) = ctx.avals_in - if not ctx.avals_out[0].shape: - raise NotImplementedError( - "Cannot lower reductions to scalar. Reduce to one element vector" - " instead, using keepdims=True." - ) - - out_type = aval_to_ir_type(ctx.avals_out[0]) - if jnp.issubdtype(x_aval.dtype, jnp.floating): - kind = ir.Attribute.parse("#vector.kind") - val = ir.FloatAttr.get(ir.F32Type.get(), 0.0) - identity = ir.DenseElementsAttr.get_splat(out_type, val) - elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger): - kind = ir.Attribute.parse("#vector.kind") - raise NotImplementedError - elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger): - kind = ir.Attribute.parse("#vector.kind") - raise NotImplementedError - acc = arith.ConstantOp(out_type, identity) - op = vector.MultiDimReductionOp( - kind, - x, - acc, - ir.ArrayAttr.get( - [ir.IntegerAttr.get(ir.IntegerType.get_signless(64), a) for a in axes] - ), - ) - return op.result +REDUCE_MIN_KINDS = { + jnp.floating: vector.CombiningKind.MINIMUMF, + jnp.signedinteger: vector.CombiningKind.MINSI, + jnp.unsignedinteger: vector.CombiningKind.MINUI, +} +REDUCE_MIN_IDENTITY = { + jnp.floating: float("inf"), + jnp.signedinteger: np.iinfo(np.int32).max, +} +_reduce_min_lowering_rule = reduce_lowering_rule( + jnp.min, REDUCE_MIN_KINDS, REDUCE_MIN_IDENTITY) +lowering_rules[lax.reduce_min_p] = _reduce_min_lowering_rule +REDUCE_SUM_KINDS = { + jnp.floating: vector.CombiningKind.ADD, + jnp.signedinteger: vector.CombiningKind.ADD, + jnp.unsignedinteger: vector.CombiningKind.ADD, +} +REDUCE_SUM_IDENTITY = { + jnp.floating: 0.0, + jnp.signedinteger: 0, +} +_reduce_sum_lowering_rule = reduce_lowering_rule( + jnp.sum, REDUCE_SUM_KINDS, REDUCE_SUM_IDENTITY) lowering_rules[lax.reduce_sum_p] = _reduce_sum_lowering_rule +def _reduce_and_lowering_rule(ctx: LoweringRuleContext, x, *, axes): + def _proxy_reduce(arg, *, axes): + # Mosaic currently only supports float reductions, so we cast the boolean + # arg to a float and use reduce_min to implement reduce_and. + # TODO(justinfu): Implement this logic in Mosaic MultiDimReductionOp + # instead. + float_arg = jnp.where(arg, 1.0, 0.0) + a = jnp.min(float_arg, axis=axes) + return a > 0.0 + proxy_lowering = lower_fun( + _proxy_reduce, multiple_results=False) + return proxy_lowering(ctx, x, axes=axes) + +lowering_rules[lax.reduce_and_p] = _reduce_and_lowering_rule + + +def _reduce_or_lowering_rule(ctx: LoweringRuleContext, x, *, axes): + def _proxy_reduce(arg, *, axes): + # Mosaic currently only supports float reductions, so we cast the boolean + # arg to a float and use reduce_max to implement reduce_or. + # TODO(justinfu): Implement this logic in Mosaic MultiDimReductionOp + # instead. + float_arg = jnp.where(arg, 1.0, 0.0) + a = jnp.max(float_arg, axis=axes) + return a > 0.0 + proxy_lowering = lower_fun( + _proxy_reduce, multiple_results=False) + return proxy_lowering(ctx, x, axes=axes) + +lowering_rules[lax.reduce_or_p] = _reduce_or_lowering_rule + + def _broadcast_in_dim_lowering_rule( ctx: LoweringRuleContext, val, *, shape, broadcast_dimensions ): (aval_in,) = ctx.avals_in (aval_out,) = ctx.avals_out + + if jnp.issubdtype(aval_in.dtype, jnp.bool_): + # Direct broadcasts for bools are not supported in Mosaic due to booleans + # living in mask registers and broadcast operating on vregs. Broadcast as an + # integer instead and cast back to a bool. + # TODO(justinfu): Implement this logic in Mosaic BroadcastOp instead. + def _proxy_fun(val, *, shape, broadcast_dimensions): + int_val = jnp.where(val, 1, 0) + bcast_val = jax.lax.broadcast_in_dim(int_val, shape, broadcast_dimensions) + return bcast_val == 1 + proxy_lowering = lower_fun( + _proxy_fun, multiple_results=False) + return proxy_lowering( + ctx, val, shape=shape, broadcast_dimensions=broadcast_dimensions) + if broadcast_dimensions: out_shape_list = [1] * len(shape) for i, s in zip(broadcast_dimensions, aval_in.shape): diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 93e4abe3e422..59c000dd1068 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -158,7 +158,8 @@ class TPU_Op traits = []> : def TPU_ReductionKind : I32EnumAttr<"ReductionKind", "Reduction kind", [ I32EnumAttrCase<"SUM", 0, "sum">, - I32EnumAttrCase<"MAX", 1, "max"> + I32EnumAttrCase<"MAX", 1, "max">, + I32EnumAttrCase<"MIN", 2, "min"> ]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::tpu"; diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index f0c568aeea48..7e2dce8b9956 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -3568,6 +3568,11 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, builder.getF32Type(), APFloat::getInf(APFloat::IEEEsingle(), /*Negative=*/true)); } break; + case vector::CombiningKind::MINIMUMF: { + neutral = builder.getFloatAttr( + builder.getF32Type(), + APFloat::getInf(APFloat::IEEEsingle(), /*Negative=*/false)); + } break; default: return multi_reduction_op.emitOpError( "Not implemented: unsupported kind"); @@ -3660,6 +3665,9 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, case vector::CombiningKind::MAXIMUMF: tpu_kind = tpu::ReductionKind::MAX; break; + case vector::CombiningKind::MINIMUMF: + tpu_kind = tpu::ReductionKind::MIN; + break; default: return multi_reduction_op.emitOpError( "Not implemented: unsupported reduction kind"); @@ -3716,6 +3724,10 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, acc_vreg = builder.create( vreg.getLoc(), *acc_vreg, vreg); break; + case tpu::ReductionKind::MIN: + acc_vreg = builder.create( + vreg.getLoc(), *acc_vreg, vreg); + break; } } return absl::OkStatus(); diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py index 92e5182061b2..e369cef5e1fb 100644 --- a/tests/pallas/pallas_call_tpu_test.py +++ b/tests/pallas/pallas_call_tpu_test.py @@ -2059,47 +2059,53 @@ def outer_body(carry): class PallasCallReductionTest(PallasTPUTest): - def setUp(self): - if jtu.device_under_test() != 'tpu': - self.skipTest('Test only works on TPU') - - super().setUp() - - def test_integer_sum(self): - def kernel(x_ref, o_ref): - x = x_ref[:] - # We'd prefer to say: - # o_ref[0, 0] = jnp.sum(x) - # But this currently hits issues in both Pallas and Mosaic lowering. - r = jnp.sum(x, keepdims=True, axis=1) - r = jnp.sum(r, keepdims=True, axis=0) - o_ref[0, 0] = r[0, 0] + @parameterized.named_parameters( + ('reduce_all_true', 'all_true', jnp.all, True), + ('reduce_all_false', 'all_false', jnp.all, False), + ('reduce_all_mixed', 'one_false', jnp.all, False), + ('reduce_any_true', 'all_true', jnp.any, True), + ('reduce_any_false', 'all_false', jnp.any, False), + ('reduce_any_mixed', 'one_false', jnp.any, True), + ) + def test_reduce_boolean(self, input_type, reduction_op, expected_result): + def kernel(x_ref, ones_ref, o_ref): + # Convert float to bool with a comparison. + bool_x = x_ref[...] == ones_ref[...] + reduced_as_bool = reduction_op(bool_x, keepdims=True) + # Convert bool to float with a select. + float_value = jnp.where(reduced_as_bool, 1.0, 0.0) + o_ref[0, 0] = float_value[0, 0] + + if input_type == 'all_true': + x = jnp.ones((8, 128), dtype=jnp.float32) + elif input_type == 'all_false': + x = jnp.zeros((8, 128), dtype=jnp.float32) + elif input_type == 'one_false': + x = jnp.ones((8, 128), dtype=jnp.float32) + x = x.at[0, 0].set(0.0) + ones = jnp.ones_like(x) - x = jnp.full([8, 128], 2.0) result = pl.pallas_call( kernel, in_specs=[ - pl.BlockSpec((8, 128), lambda *_: (0, 0)), + pl.BlockSpec(lambda *_: (0, 0), (8, 128)), + pl.BlockSpec(lambda *_: (0, 0), (8, 128)), ], - out_specs=pl.BlockSpec((1, 1), memory_space=pltpu.SMEM), + out_specs=pl.BlockSpec(block_shape=(1, 1), memory_space=pltpu.SMEM), out_shape=jax.ShapeDtypeStruct([1, 1], jnp.float32), grid=(1,), - )(x) - - np.testing.assert_array_equal(result[0, 0], 2048.0) + )(x, ones) + np.testing.assert_array_equal(result[0, 0], float(expected_result)) - def test_integer_max(self): + @parameterized.named_parameters( + ('sum', jnp.sum,), ('max', jnp.max,), ('min', jnp.min,) + ) + def test_reduce_float(self, reduction_op): def kernel(x_ref, o_ref): - x = x_ref[:] - # We'd prefer to say: - # o_ref[0, 0] = jnp.max(x) - # But this currently hits issues in both Pallas and Mosaic lowering. - x = jnp.max(x, keepdims=True, axis=1) - x = jnp.max(x, keepdims=True, axis=0) - o_ref[0, 0] = x[0, 0] - - x = jnp.arange(1024.0) - x = jnp.reshape(x, [8, 128]) + r = reduction_op(x_ref[...], keepdims=True) + o_ref[0, 0] = r[0, 0] + + x = jax.random.normal(jax.random.key(0), (8, 128)) result = pl.pallas_call( kernel, in_specs=[ @@ -2110,7 +2116,7 @@ def kernel(x_ref, o_ref): grid=(1,), )(x) - np.testing.assert_array_equal(result[0, 0], 1023.0) + np.testing.assert_allclose(result[0, 0], reduction_op(x), atol=1e-5) class PallasCallDynamicDMATest(PallasTPUTest):