Skip to content

Commit

Permalink
Don't wrap singleton ir.Types during HLO lowering.
Browse files Browse the repository at this point in the history
This is similar to #22211, but for MLIR types instead of MLIR values.
  • Loading branch information
hawkinsp committed Jul 8, 2024
1 parent 9405d46 commit 3d5784a
Show file tree
Hide file tree
Showing 11 changed files with 139 additions and 118 deletions.
9 changes: 5 additions & 4 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,10 +777,11 @@ def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool,
differentiated=differentiated, policy=policy,
is_gpu_platform=is_gpu_platform)

arg_types = map(mlir.aval_to_ir_types, ctx.avals_in)
arg_types = map(mlir.aval_to_ir_type, ctx.avals_in)
flat_args = mlir.flatten_ir_values(args)
barrier_op = hlo.OptimizationBarrierOp(flat_args)
jaxpr_args = mlir.unflatten_ir_values(barrier_op.results, map(len, arg_types))
jaxpr_args = mlir.unflatten_ir_values_like_types(
barrier_op.results, arg_types)
else:
jaxpr_args = args
outs, tokens_out = mlir.jaxpr_subcomp(
Expand All @@ -797,10 +798,10 @@ def _optimization_barrier_abstract_eval(*args):
return args

def _optimization_barrier_lowering_rule(ctx, *args):
barrier_types = map(mlir.aval_to_ir_types, ctx.avals_in)
barrier_types = map(mlir.aval_to_ir_type, ctx.avals_in)
flat_args = mlir.flatten_ir_values(args)
barrier_op = hlo.OptimizationBarrierOp(flat_args)
return mlir.unflatten_ir_values(barrier_op.results, map(len, barrier_types))
return mlir.unflatten_ir_values_like_types(barrier_op.results, barrier_types)

def _optimization_barrier(arg):
flat_args, treedef = tree_flatten(arg)
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ def _wrap_main_func(
orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value

def is_token(typ, attrs):
return (typ == mlir.token_type()[0])
return (typ == mlir.token_type())

orig_input_types = orig_main.type.inputs # type: ignore
arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs)) # type: ignore
Expand Down Expand Up @@ -1329,7 +1329,7 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra
else:
current_platform_idx = cast(ir.Value, mlir.ir_constant(np.int32(0)))
# Compute the rule index based on the current platform
i32_type = mlir.aval_to_ir_types(core.ShapedArray((), dtype=np.int32))[0]
i32_type = mlir.aval_to_ir_type(core.ShapedArray((), dtype=np.int32))
if current_platform_idx.type != i32_type:
current_platform_idx = hlo.ConvertOp(i32_type, current_platform_idx)
callee_platform_idx = hlo.CaseOp([i32_type],
Expand Down
Loading

0 comments on commit 3d5784a

Please sign in to comment.