Merge branch 'google:main' into main
coreyjadams committed May 29, 2024
2 parents 301bbc6 + 6c51234 commit 5a91ac3
Showing 32 changed files with 610 additions and 231 deletions.
2 changes: 1 addition & 1 deletion docs/quickstart.md
@@ -31,7 +31,7 @@ pip install "jax[cpu]"
```
or, for NVIDIA GPU:
```
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -U "jax[cuda12]"
```
For more detailed platform-specific installation information, check out {ref}`installation`.
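A quick way to confirm which backend the installed wheels picked up (a hedged check, not part of this diff; expect a GPU backend only after a successful `jax[cuda12]` install):
```python
import jax

print(jax.default_backend())  # 'gpu' when the CUDA wheels are active, else 'cpu'
print(jax.devices())          # e.g. [CudaDevice(id=0)] on a single-GPU machine
```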

1 change: 1 addition & 0 deletions jax/BUILD
@@ -389,6 +389,7 @@ pytype_strict_library(
name = "cloud_tpu_init",
srcs = ["_src/cloud_tpu_init.py"],
deps = [
":config",
":hardware_utils",
":version",
],
4 changes: 2 additions & 2 deletions jax/_src/basearray.pyi
@@ -237,10 +237,10 @@ class _IndexUpdateHelper:

class _IndexUpdateRef:
def get(self, indices_are_sorted: bool = False, unique_indices: bool = False,
mode: Optional[str] = None, fill_value: Optional[ArrayLike] = None) -> Array: ...
mode: Optional[str] = None, fill_value: Optional[StaticScalar] = None) -> Array: ...
def set(self, values: Any,
indices_are_sorted: bool = False, unique_indices: bool = False,
mode: Optional[str] = None, fill_value: Optional[ArrayLike] = None) -> Array: ...
mode: Optional[str] = None, fill_value: Optional[StaticScalar] = None) -> Array: ...
def add(self, values: Any, indices_are_sorted: bool = False,
unique_indices: bool = False, mode: Optional[str] = None) -> Array: ...
def mul(self, values: Any, indices_are_sorted: bool = False,
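For context, a hedged usage sketch of the `.at[...]` helpers whose stub signatures are tightened here (`fill_value` is now annotated as a static scalar); the behaviour shown is as documented for current JAX, not something introduced by this commit:
```python
import jax.numpy as jnp

x = jnp.arange(5)
# With mode='fill', an out-of-bounds gather returns fill_value, which must be
# a static scalar rather than an arbitrary array-like.
print(x.at[10].get(mode='fill', fill_value=-1))  # -1
print(x.at[2].set(99))                           # [ 0  1 99  3  4]
```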
9 changes: 8 additions & 1 deletion jax/_src/cloud_tpu_init.py
@@ -13,8 +13,9 @@
# limitations under the License.

import os
from jax._src import hardware_utils
from jax import version
from jax._src import config
from jax._src import hardware_utils

running_in_cloud_tpu_vm: bool = False

@@ -73,3 +74,9 @@ def cloud_tpu_init() -> None:
# this makes tensorstore serialization work better on TPU
os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60')
os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES', '256')

if config.jax_pjrt_client_create_options.value is None:
config.update(
'jax_pjrt_client_create_options',
f'ml_framework_name:JAX;ml_framework_version:{version.__version__}'
)
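The string built above follows the `k1:v1;k2:v2` format described by the new config option; a small illustration of that format (the real parsing happens inside the PJRT client, so the dict below is purely hypothetical):
```python
from jax import version

opts = f'ml_framework_name:JAX;ml_framework_version:{version.__version__}'
# Hypothetical split, shown only to make the "k1:v1;k2:v2" layout concrete.
parsed = dict(kv.split(':', 1) for kv in opts.split(';'))
print(parsed)
```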
8 changes: 4 additions & 4 deletions jax/_src/compute_on.py
@@ -28,11 +28,11 @@ def __init__(self):
@contextmanager
def extend_compute_type(c_type: str):
compute_on_context.stack.append(c_type)
if len(set(filter(lambda x: x is not None, set(compute_on_context.stack)))) > 1:
raise NotImplementedError(
'Nesting `compute_on` with different compute types is not supported'
f' yet. Current stack: {compute_on_context.stack}')
try:
if len(set(filter(lambda x: x is not None, set(compute_on_context.stack)))) > 1:
raise NotImplementedError(
'Nesting `compute_on` with different compute types is not supported'
f' yet. Current stack: {compute_on_context.stack}')
yield compute_on_context.stack[-1]
finally:
compute_on_context.stack.pop()
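The effect of this hunk is that the nesting check now runs inside the `try`, so the `finally` still pops the entry that was just pushed when the error is raised; a minimal standalone sketch of the pattern (hypothetical names, not JAX code):
```python
from contextlib import contextmanager

stack: list[str] = []

@contextmanager
def extend(value: str):
    stack.append(value)
    try:
        # Raising here (inside try) guarantees the finally below still runs,
        # so the stack never keeps a stale entry after the error.
        if len(set(stack)) > 1:
            raise NotImplementedError('nested values differ')
        yield stack[-1]
    finally:
        stack.pop()
```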
6 changes: 6 additions & 0 deletions jax/_src/config.py
@@ -935,6 +935,12 @@ def update_thread_local_jit_state(**kw):
'otherwise.'
))

jax_pjrt_client_create_options = define_optional_string_state(
name='jax_pjrt_client_create_options',
default=None,
help=('A set of key-value pairs in the format of "k1:v1;k2:v2" strings '
'provided to a device platform pjrt client as extra arguments.'))

enable_checks = define_bool_state(
name='jax_enable_checks',
default=False,
59 changes: 35 additions & 24 deletions jax/_src/interpreters/mlir.py
@@ -90,16 +90,7 @@ def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
return type_cast(ir.DenseIntElementsAttr,
ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)))

def dense_int_array(xs) -> ir.DenseElementsAttr | ir.DenseI64ArrayAttr:
# TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v5 or higher
if hlo.get_api_version() < 5:
return dense_int_elements(xs)
return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64)) # type: ignore

# TODO: b/321794305 - delete this when jaxlib is on StableHLO API v6 or higher
def dense_int_array_v6(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
if hlo.get_api_version() < 6:
return dense_int_elements(xs)
def dense_int_array(xs) -> ir.DenseI64ArrayAttr:
return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64)) # type: ignore

def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
@@ -111,10 +102,7 @@ def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
return ir.DenseElementsAttr.get(
a, type=ir.IntegerType.get_signless(1), shape=[len(xs)])

def dense_bool_array(xs: Sequence[bool]) -> ir.DenseElementsAttr | ir.DenseBoolArrayAttr:
# TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v6 or higher
if hlo.get_api_version() < 6:
return dense_bool_elements(xs)
def dense_bool_array(xs: Sequence[bool]) -> ir.DenseBoolArrayAttr:
return ir.DenseBoolArrayAttr.get(xs) # type: ignore

def i32_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i)
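For reference, the two attribute kinds involved in this cleanup (a hedged illustration, assuming jaxlib's bundled MLIR Python bindings behave like upstream MLIR; the removed `*_v6` helpers only existed to fall back to the tensor-typed form on old StableHLO API versions):
```python
import numpy as np
from jax._src.lib.mlir import ir

with ir.Context():
    # Tensor-typed form emitted by dense_int_elements / the old fallbacks.
    elements = ir.DenseIntElementsAttr.get(np.asarray([1, 2, 3], np.int64))
    # Array-typed form that dense_int_array now always emits.
    array = ir.DenseI64ArrayAttr.get(np.asarray([1, 2, 3], np.int64))
    print(elements)  # roughly: dense<[1, 2, 3]> : tensor<3xi64>
    print(array)     # roughly: array<i64: 1, 2, 3>
```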
@@ -321,7 +309,7 @@ def _ndarray_constant_handler(val: np.ndarray | np.generic) -> Sequence[ir.Value
ir.RankedTensorType.get(
val.shape, dtype_to_ir_type(collapsed_val.dtype)), # type: ignore
_numpy_array_constant(collapsed_val)[0],
dense_int_array_v6(other_axes))
dense_int_array(other_axes))
return (out,)
else:
return _numpy_array_constant(val)
@@ -1694,7 +1682,11 @@ def lower_per_platform(ctx: LoweringRuleContext,
assert kept_rules
# If there is a single rule left just apply the rule, without conditionals.
if len(kept_rules) == 1:
return kept_rules[0](ctx, *rule_args, **rule_kwargs)
output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
wrapped_out = map(wrap_singleton_ir_values, output)
map(lambda o: wrap_compute_type_in_place(ctx, o.owner),
util.flatten(wrapped_out))
return output

assert len(platforms) > 1 and len(kept_rules) >= 2, (platforms, kept_rules)
assert len(ctx.dim_var_values) >= 1, "Must have a platform_index variable"
@@ -1728,6 +1720,8 @@ def lower_per_platform(ctx: LoweringRuleContext,
except TypeError as e:
raise ValueError("Output of translation rule must be iterable: "
f"{description}, got output {output}") from e
map(lambda o: wrap_compute_type_in_place(ctx, o.owner),
util.flatten(out_nodes))
if inner_ctx.tokens_out is not None:
assert len(ordered_effects) == len(inner_ctx.tokens_out)
out_nodes = [inner_ctx.tokens_out.get(eff)
@@ -1866,6 +1860,21 @@ def core_call_lowering(ctx: LoweringRuleContext,
register_lowering(core.closed_call_p,
partial(core_call_lowering, name=None))

def map_compute_type(c_type):
if c_type == 'device_host':
return 'host'
elif c_type == 'device':
return 'dense'
raise ValueError('Invalid compute type received. Current supported values '
'are `device_host` and `device`')

def wrap_compute_type_in_place(ctx, op):
if ctx.compute_type is not None:
dict_attr = {"_xla_compute_type": ir.StringAttr.get(
map_compute_type(ctx.compute_type))}
op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)


def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *,
broadcast_dimensions) -> ir.Value:
# broadcast_dimension[i] is the axis of the result where the axis i of
@@ -1882,17 +1891,19 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue,
else:
if not core.is_constant_shape(aval_out.shape): # type: ignore
shape = eval_dynamic_shape_as_tensor(ctx, aval_out.shape) # type: ignore
return hlo.dynamic_broadcast_in_dim(
out = hlo.dynamic_broadcast_in_dim(
aval_to_ir_type(aval_out), op,
shape,
dense_int_array_v6(broadcast_dimensions),
dense_int_array(broadcast_dimensions),
)
else:
assert all(d != ir.ShapedType.get_dynamic_size()
for d in aval_out.shape), aval_out # type: ignore
return hlo.broadcast_in_dim(
out = hlo.broadcast_in_dim(
aval_to_ir_type(aval_out), op,
dense_int_array_v6(broadcast_dimensions))
dense_int_array(broadcast_dimensions))
wrap_compute_type_in_place(ctx, out.owner)
return out

def multi_broadcast_in_dim(ctx: LoweringRuleContext,
ops: Sequence[ir.Value],
@@ -2725,10 +2736,10 @@ def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]):
rw = hlo.ReduceWindowOp(
list(map(aval_to_ir_type, out_avals)),
operands, init_values,
dense_int_array_v6(window_dimensions),
window_strides=dense_int_array_v6(window_strides),
base_dilations=dense_int_array_v6(base_dilation),
window_dilations=dense_int_array_v6(window_dilation),
dense_int_array(window_dimensions),
window_strides=dense_int_array(window_strides),
base_dilations=dense_int_array(base_dilation),
window_dilations=dense_int_array(window_dilation),
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
shape=[len(padding), 2]))
reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
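The `map_compute_type` / `wrap_compute_type_in_place` helpers added above tag lowered ops with an `mhlo.frontend_attributes` entry. A hedged end-to-end sketch using the experimental `compute_on` API (illustration only; it assumes a jax/jaxlib build new enough to support host offload, and the exact printed IR varies by version):
```python
import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on

def f(x):
    # Ops lowered inside this scope should carry
    # mhlo.frontend_attributes = {_xla_compute_type = "host"}.
    with compute_on('device_host'):
        y = x * 2
    return y + 1

print(jax.jit(f).lower(jnp.ones(4)).as_text())
```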
12 changes: 6 additions & 6 deletions jax/_src/lax/convolution.py
@@ -719,10 +719,10 @@ def _conv_general_dilated_lower(
dimension_numbers=dnums,
feature_group_count=mlir.i64_attr(feature_group_count),
batch_group_count=mlir.i64_attr(batch_group_count),
window_strides=mlir.dense_int_array_v6(window_strides),
window_strides=mlir.dense_int_array(window_strides),
padding=mlir.dense_int_elements(padding),
lhs_dilation=mlir.dense_int_array_v6(lhs_dilation),
rhs_dilation=mlir.dense_int_array_v6(rhs_dilation),
lhs_dilation=mlir.dense_int_array(lhs_dilation),
rhs_dilation=mlir.dense_int_array(rhs_dilation),
window_reversal=window_reversal,
precision_config=lax.precision_attr(precision))
]
@@ -744,9 +744,9 @@ def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]):
dimension_numbers=dnums,
feature_group_count=mlir.i64_attr(feature_group_count),
batch_group_count=mlir.i64_attr(batch_group_count),
window_strides=mlir.dense_int_array_v6(window_strides),
lhs_dilation=mlir.dense_int_array_v6(lhs_dilation),
rhs_dilation=mlir.dense_int_array_v6(rhs_dilation),
window_strides=mlir.dense_int_array(window_strides),
lhs_dilation=mlir.dense_int_array(lhs_dilation),
rhs_dilation=mlir.dense_int_array(rhs_dilation),
window_reversal=window_reversal,
precision_config=lax.precision_attr(precision))
]
6 changes: 3 additions & 3 deletions jax/_src/lax/lax.py
@@ -1760,7 +1760,7 @@ def broadcast_hlo(
for aval, arg in zip(avals, args):
if aval.shape != aval_out.shape:
assert len(aval.shape) <= len(aval_out.shape), (aval, aval_out)
dims = mlir.dense_int_array_v6(
dims = mlir.dense_int_array(
range(len(aval_out.shape) - len(aval.shape), len(aval_out.shape)))
if any(isinstance(d, ir.Value) for d in aval_out.shape):
arg = hlo.dynamic_broadcast_in_dim(
@@ -3963,7 +3963,7 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions):
operands, init_values = util.split_list(values, [len(values) // 2])
init_value_avals = ctx.avals_in[len(values) // 2:]
op = hlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
operands, init_values, mlir.dense_int_array_v6(dimensions))
operands, init_values, mlir.dense_int_array(dimensions))
ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
reducer = op.regions[0].blocks.append(*(ir_types + ir_types))
with ir.InsertionPoint(reducer):
@@ -4174,7 +4174,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
dtype = aval_out.dtype
op = hlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x],
mlir.ir_constants(unit_factory(aval_out.dtype)),
mlir.dense_int_array_v6(axes))
mlir.dense_int_array(axes))
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), dtype))
reducer_region = op.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer_region):
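One way to see the `DenseI64ArrayAttr` these reduce lowerings now emit is to dump the StableHLO for a small reduction (a sketch; the exact attribute text depends on the installed jax/jaxlib version):
```python
import jax
import jax.numpy as jnp

# With a new enough jaxlib, the reduce op's `dimensions` should print as
# array<i64: 1> rather than dense<1> : tensor<1xi64>.
print(jax.jit(lambda x: jnp.sum(x, axis=1)).lower(jnp.ones((2, 3))).as_text())
```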
6 changes: 4 additions & 2 deletions jax/_src/lax/parallel.py
@@ -983,8 +983,10 @@ def source_to_front(group):
return [group[source]] + list(group[:source]) + list(group[source + 1:])
replica_groups = [source_to_front(group) for group in replica_groups]
channel = ctx.module_context.new_channel()
channel_handle = hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE)
return hlo.CollectiveBroadcastOp(
x, replica_groups=_replica_groups_hlo(replica_groups)).results
x, replica_groups=_replica_groups_hlo(replica_groups),
channel_handle=channel_handle).results

pbroadcast_p = core.AxisPrimitive('pbroadcast')
pbroadcast_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
@@ -1271,7 +1273,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
broadcast_dimensions = [i for i in range(len(new_shape)) if i != all_gather_dimension]
x = hlo.broadcast_in_dim(
mlir.aval_to_ir_type(x_aval.update(shape=new_shape)), x,
mlir.dense_int_array_v6(broadcast_dimensions))
mlir.dense_int_array(broadcast_dimensions))
replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
axis_index_groups)
if is_spmd:
2 changes: 1 addition & 1 deletion jax/_src/lax/slicing.py
@@ -1845,7 +1845,7 @@ def _gather_lower(ctx, operand, indices, *,
operand,
indices,
dnums,
mlir.dense_int_array_v6(slice_sizes),
mlir.dense_int_array(slice_sizes),
indices_are_sorted=ir.BoolAttr.get(indices_are_sorted))]

mlir.register_lowering(gather_p, _gather_lower)
4 changes: 2 additions & 2 deletions jax/_src/lax/windowed_reductions.py
@@ -665,8 +665,8 @@ def _select_and_scatter_lower(
operand,
source,
init_value,
window_dimensions=mlir.dense_int_array_v6(window_dimensions),
window_strides=mlir.dense_int_array_v6(window_strides),
window_dimensions=mlir.dense_int_array(window_dimensions),
window_strides=mlir.dense_int_array(window_strides),
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
shape=(len(padding), 2)))
select = op.select.blocks.append(scalar_type, scalar_type)