From 818e7d92a4b88d17c73e8eb4d4b26d9e265a4e0d Mon Sep 17 00:00:00 2001
From: Seonghyeon
Date: Tue, 28 May 2024 13:17:28 +0000
Subject: [PATCH 01/18] Fix rel_entr behavior at boundary value
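
For reference, the boundary semantics this change targets (matching
`scipy.special.rel_entr`) are sketched below. The snippet is illustrative
only and is not part of the patch:

```
import jax.numpy as jnp
from jax.scipy.special import rel_entr

print(rel_entr(0.0, 0.5))   # 0.0: p == 0, q >= 0 -> 0
print(rel_entr(0.5, 0.0))   # inf: p > 0, q == 0 -> inf
print(rel_entr(-1.0, 0.5))  # inf: outside the domain -> inf
```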
---
 jax/_src/scipy/special.py                 | 2 +-
 tests/lax_scipy_special_functions_test.py | 9 +++++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py
index 64c7133e327c..d148329b7b4e 100644
--- a/jax/_src/scipy/special.py
+++ b/jax/_src/scipy/special.py
@@ -672,7 +672,7 @@ def rel_entr(
   safe_q = jnp.where(both_gt_zero_mask, q, 1)
   log_val = lax.sub(_xlogx(safe_p), xlogy(safe_p, safe_q))
   result = jnp.where(
-      both_gt_zero_mask, log_val, jnp.where(one_zero_mask, q, jnp.inf)
+      both_gt_zero_mask, log_val, jnp.where(one_zero_mask, zero, jnp.inf)
   )
   return result
 
diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py
index f5fb042c6155..f950d1048df4 100644
--- a/tests/lax_scipy_special_functions_test.py
+++ b/tests/lax_scipy_special_functions_test.py
@@ -228,6 +228,15 @@ def testNdtriExtremeValues(self):
     self._CheckAgainstNumpy(osp_special.ndtri, lsp_special.ndtri, args_maker, rtol=rtol)
     self._CompileAndCheck(lsp_special.ndtri, args_maker, rtol=rtol)
 
+  def testRelEntrExtremeValues(self):
+    # Testing at the extreme values (bounds (0. and 1.) and outside the bounds).
+    dtype = jax.numpy.zeros(0).dtype  # default float dtype.
+    args_maker = lambda: [np.array([-2, -2, -2, -1, -1, -1, 0, 0, 0]).astype(dtype),
+                          np.array([-1, 0, 1, -1, 0, 1, -1, 0, 1]).astype(dtype)]
+    rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
+    self._CheckAgainstNumpy(osp_special.rel_entr, lsp_special.rel_entr, args_maker, rtol=rtol)
+    self._CompileAndCheck(lsp_special.rel_entr, args_maker, rtol=rtol)
+
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())

From 2d23a66c6a8c0798f5566f952ce4563a87f559c8 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Tue, 28 May 2024 07:12:54 -0700
Subject: [PATCH 02/18] jnp.take_along_axis: support fill_value
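
The new argument mirrors the ``fill_value`` option already accepted by
``jnp.take``. A minimal usage sketch, mirroring the test added below:

```
import jax.numpy as jnp

x = jnp.arange(5.0)
ind = jnp.array([0, 2, 4, 6])  # index 6 is out of bounds
jnp.take_along_axis(x, ind, axis=None, mode='fill', fill_value=10.0)
# Array([ 0.,  2.,  4., 10.], dtype=float32)
```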
---
 jax/_src/numpy/lax_numpy.py | 10 ++++++----
 jax/numpy/__init__.pyi      |  1 +
 tests/lax_numpy_test.py     |  9 ++++++++-
 3 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 9d4cdcd8865e..af4a1e4451a8 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -5725,12 +5725,13 @@ def _normalize_index(index, axis_size):
 
 @util.implements(np.take_along_axis, update_doc=False,
                  lax_description=TAKE_ALONG_AXIS_DOC)
-@partial(jit, static_argnames=('axis', 'mode'))
+@partial(jit, static_argnames=('axis', 'mode', 'fill_value'))
 def take_along_axis(
     arr: ArrayLike,
     indices: ArrayLike,
     axis: int | None,
     mode: str | lax.GatherScatterMode | None = None,
+    fill_value: StaticScalar | None = None,
 ) -> Array:
   util.check_arraylike("take_along_axis", arr, indices)
   a = asarray(arr)
@@ -5743,8 +5744,9 @@ def take_along_axis(
     if ndim(indices) != 1:
       msg = "take_along_axis indices must be 1D if axis=None, got shape {}"
       raise ValueError(msg.format(idx_shape))
-    return take_along_axis(a.ravel(), indices, 0)
-  rank = ndim(arr)
+    a = a.ravel()
+    axis = 0
+  rank = a.ndim
   if rank != ndim(indices):
     msg = "indices and arr must have the same number of dimensions; {} vs. {}"
     raise ValueError(msg.format(ndim(indices), a.ndim))
@@ -5812,7 +5814,7 @@ def replace(tup, val):
         collapsed_slice_dims=tuple(collapsed_slice_dims),
         start_index_map=tuple(start_index_map))
   return lax.gather(a, gather_indices_arr, dnums, tuple(slice_sizes),
-                    mode="fill" if mode is None else mode)
+                    mode="fill" if mode is None else mode, fill_value=fill_value)
 
 
 ### Indexing
diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi
index 1ee2d642e578..2cf251a36258 100644
--- a/jax/numpy/__init__.pyi
+++ b/jax/numpy/__init__.pyi
@@ -811,6 +811,7 @@ def take_along_axis(
     indices: ArrayLike,
     axis: Optional[int],
     mode: Optional[Union[str, GatherScatterMode]] = ...,
+    fill_value: Optional[StaticScalar] = None,
 ) -> Array: ...
 def tan(x: ArrayLike, /) -> Array: ...
 def tanh(x: ArrayLike, /) -> Array: ...
diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index f46933c169c9..050f39cedd15 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -4483,6 +4483,13 @@ def testTakeAlongAxisWithEmptyArgs(self):
     x = jnp.ones((4, 0, 3), dtype=jnp.int32)
     np.testing.assert_array_equal(x, jnp.take_along_axis(x, x, axis=1))
 
+  def testTakeAlongAxisOptionalArgs(self):
+    x = jnp.arange(5.0)
+    ind = jnp.array([0, 2, 4, 6])
+    expected = jnp.array([0.0, 2.0, 4.0, 10.0], dtype=x.dtype)
+    actual = jnp.take_along_axis(x, ind, axis=None, mode='fill', fill_value=10.0)
+    self.assertArraysEqual(expected, actual)
+
   @jtu.sample_product(
     dtype=inexact_dtypes,
     shape=[0, 5],
@@ -5973,7 +5980,7 @@ def testWrappedSignaturesMatch(self):
       'clip': ['x', 'max', 'min'],
       'einsum': ['subscripts', 'precision'],
       'einsum_path': ['subscripts'],
-      'take_along_axis': ['mode'],
+      'take_along_axis': ['mode', 'fill_value'],
       'fill_diagonal': ['inplace'],
     }

From 3ade02326dbe35c219901d0f3b2d717fcf2204f4 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Tue, 28 May 2024 07:15:15 -0700
Subject: [PATCH 03/18] Fix type annotation for index-update
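
The annotation now reflects that ``fill_value`` is expected to be a static
scalar rather than an arbitrary array-like. A sketch of the affected API,
for illustration only:

```
import jax.numpy as jnp

x = jnp.arange(5.0)
x.at[10].get(mode='fill', fill_value=-1.0)  # Array(-1., dtype=float32)
```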
---
 jax/_src/basearray.pyi | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi
index 737940512e56..6f37c16e6715 100644
--- a/jax/_src/basearray.pyi
+++ b/jax/_src/basearray.pyi
@@ -237,10 +237,10 @@ class _IndexUpdateHelper:
 class _IndexUpdateRef:
   def get(self, indices_are_sorted: bool = False, unique_indices: bool = False,
-          mode: Optional[str] = None, fill_value: Optional[ArrayLike] = None) -> Array: ...
+          mode: Optional[str] = None, fill_value: Optional[StaticScalar] = None) -> Array: ...
   def set(self, values: Any,
           indices_are_sorted: bool = False, unique_indices: bool = False,
-          mode: Optional[str] = None, fill_value: Optional[ArrayLike] = None) -> Array: ...
+          mode: Optional[str] = None, fill_value: Optional[StaticScalar] = None) -> Array: ...
   def add(self, values: Any, indices_are_sorted: bool = False,
           unique_indices: bool = False, mode: Optional[str] = None) -> Array: ...
   def mul(self, values: Any, indices_are_sorted: bool = False,

From 6803d771d5969d288e4b4029d3295fc88f54e487 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Tue, 28 May 2024 08:24:49 -0700
Subject: [PATCH 04/18] Improve docstrings for unravel_index & ravel_multi_index

---
 jax/_src/numpy/lax_numpy.py | 102 +++++++++++++++++++++++++++++++++---
 1 file changed, 95 insertions(+), 7 deletions(-)

diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 9d4cdcd8865e..6c94f5b4b4f1 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -1043,9 +1043,62 @@ def ravel(a: ArrayLike, order: str = "C") -> Array:
   return reshape(a, (size(a),), order)
 
 
-@util.implements(np.ravel_multi_index)
 def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int],
                       mode: str = 'raise', order: str = 'C') -> Array:
+  """Convert multi-dimensional indices into flat indices.
+
+  JAX implementation of :func:`numpy.ravel_multi_index`.
+
+  Args:
+    multi_index: sequence of integer arrays containing indices in each dimension.
+    dims: sequence of integer sizes; must have ``len(dims) == len(multi_index)``.
+    mode: how to handle out-of-bound indices. Options are
+
+      - ``"raise"`` (default): raise a ValueError. This mode is incompatible
+        with :func:`~jax.jit` or other JAX transformations.
+      - ``"clip"``: clip out-of-bound indices to valid range.
+      - ``"wrap"``: wrap out-of-bound indices to valid range.
+
+    order: ``"C"`` (default) or ``"F"``, specify whether to assume C-style
+      row-major order or Fortran-style column-major order.
+
+  Returns:
+    array of flattened indices
+
+  See also:
+    :func:`jax.numpy.unravel_index`: inverse of this function.
+
+  Example:
+    Define a 2-dimensional array and a sequence of indices of even values:
+
+    >>> x = jnp.array([[2., 3., 4.],
+    ...                [5., 6., 7.]])
+    >>> indices = jnp.where(x % 2 == 0)
+    >>> indices
+    (Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32))
+    >>> x[indices]
+    Array([2., 4., 6.], dtype=float32)
+
+    Compute the flattened indices:
+
+    >>> indices_flat = jnp.ravel_multi_index(indices, x.shape)
+    >>> indices_flat
+    Array([0, 2, 4], dtype=int32)
+
+    These flattened indices can be used to extract the same values from the
+    flattened ``x`` array:
+
+    >>> x_flat = x.ravel()
+    >>> x_flat
+    Array([2., 3., 4., 5., 6., 7.], dtype=float32)
+    >>> x_flat[indices_flat]
+    Array([2., 4., 6.], dtype=float32)
+
+    The original indices can be recovered with :func:`~jax.numpy.unravel_index`:
+
+    >>> jnp.unravel_index(indices_flat, x.shape)
+    (Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32))
+  """
   assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}"
   dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims)
   util.check_arraylike("ravel_multi_index", *multi_index)
@@ -1081,13 +1134,48 @@ def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int],
   return result
 
 
-_UNRAVEL_INDEX_DOC = """\
-Unlike numpy's implementation of unravel_index, negative indices are accepted
-and out-of-bounds indices are clipped into the valid range.
-"""
-
-@util.implements(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC)
 def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]:
+  """Convert flat indices into multi-dimensional indices.
+
+  JAX implementation of :func:`numpy.unravel_index`. The JAX version differs in
+  its treatment of out-of-bound indices: unlike NumPy, negative indices are
+  supported, and out-of-bound indices are clipped to the nearest valid value.
+
+  Args:
+    indices: integer array of flat indices
+    shape: shape of multidimensional array to index into
+
+  Returns:
+    Tuple of unraveled indices
+
+  See also:
+    :func:`jax.numpy.ravel_multi_index`: Inverse of this function.
+
+  Examples:
+    Start with a 1D array of values and indices:
+
+    >>> x = jnp.array([2., 3., 4., 5., 6., 7.])
+    >>> indices = jnp.array([1, 3, 5])
+    >>> print(x[indices])
+    [3. 5. 7.]
+
+    Now if ``x`` is reshaped, ``unravel_index`` can be used to convert
+    the flat indices into a tuple of indices that access the same entries:
+
+    >>> shape = (2, 3)
+    >>> x_2D = x.reshape(shape)
+    >>> indices_2D = jnp.unravel_index(indices, shape)
+    >>> indices_2D
+    (Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
+    >>> print(x_2D[indices_2D])
+    [3. 5. 7.]
+
+    The inverse function, ``ravel_multi_index``, can be used to obtain the
+    original indices:
+
+    >>> jnp.ravel_multi_index(indices_2D, shape)
+    Array([1, 3, 5], dtype=int32)
+  """
   util.check_arraylike("unravel_index", indices)
   indices_arr = asarray(indices)
   # Note: we do not convert shape to an array, because it may be passed as a

From af6970e4320f8f192e72a9d6fc2690f6d26001a6 Mon Sep 17 00:00:00 2001
From: Chase Roberts
Date: Mon, 27 May 2024 16:27:07 -0700
Subject: [PATCH 05/18] Pipe channel handle

---
 jax/_src/lax/parallel.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py
index 0d1ed7783208..558b27082a65 100644
--- a/jax/_src/lax/parallel.py
+++ b/jax/_src/lax/parallel.py
@@ -983,8 +983,10 @@ def source_to_front(group):
     return [group[source]] + list(group[:source]) + list(group[source + 1:])
   replica_groups = [source_to_front(group) for group in replica_groups]
   channel = ctx.module_context.new_channel()
+  channel_handle = hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE)
   return hlo.CollectiveBroadcastOp(
-      x, replica_groups=_replica_groups_hlo(replica_groups)).results
+      x, replica_groups=_replica_groups_hlo(replica_groups),
+      channel_handle=channel_handle).results
 
 pbroadcast_p = core.AxisPrimitive('pbroadcast')
 pbroadcast_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))

From 822bc3647d9ef6ee783e602f9fcd63fbb7534907 Mon Sep 17 00:00:00 2001
From: Sam Ritchie
Date: Tue, 28 May 2024 13:31:14 -0400
Subject: [PATCH 06/18] Update quickstart to suggested jax[cuda12] installation

This PR updates the quickstart docs to match the README and Installation
Guide suggestion of using

```
pip install jax[cuda12]
```

for GPU.
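
A quick way to confirm the install picked up the intended backend
(illustrative only; the exact device names vary by platform):

```
import jax
print(jax.devices())  # e.g. [CudaDevice(id=0)] on a working GPU install
```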
---
 docs/quickstart.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/quickstart.md b/docs/quickstart.md
index 8be061e90489..5c3562b8b2ea 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -31,7 +31,7 @@ pip install "jax[cpu]"
 ```
 or, for NVIDIA GPU:
 ```
-pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+pip install -U "jax[cuda12]"
 ```
 For more detailed platform-specific installation information, check out {ref}`installation`.

From 8b95853609e7c900d0ba17f7b78875a4f6ce3ef8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1s=20Longeri?=
Date: Tue, 28 May 2024 10:46:44 -0700
Subject: [PATCH 07/18] [Mosaic] Add relayout for (1, 128 * packing) ->
 (packing, 128).

PiperOrigin-RevId: 637951690
---
 jaxlib/mosaic/dialect/tpu/tpu.td              |  17 ++-
 .../tpu/transforms/apply_vector_layout.cc     | 100 ++++++++++++++++--
 2 files changed, 109 insertions(+), 8 deletions(-)

diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td
index 20643f089c9f..35d1cabe9fa6 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu.td
+++ b/jaxlib/mosaic/dialect/tpu/tpu.td
@@ -97,6 +97,18 @@ def TPU_ContractPrecisionEnum
   let assemblyFormat = "`<` $value `>`";
 }
 
+def TPU_PackFormat : I32EnumAttr<"PackFormat", "Pack format", [
+  I32EnumAttrCase<"kCompressed", 0, "compressed">,
+  I32EnumAttrCase<"kInterleaved", 1, "interleaved">
+]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::tpu";
+}
+
+def TPU_PackFormatEnum : EnumAttr<TPU_Dialect, TPU_PackFormat, "pack_format"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 def TPU_TiledCase    : I32EnumAttrCase<"tiled", 0>;
 def TPU_LaneCase     : I32EnumAttrCase<"lanes", 1>;
 def TPU_SublaneCase  : I32EnumAttrCase<"sublanes", 2>;
@@ -278,7 +290,10 @@ def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> {
 
 // Integer packs are always signed at the moment.
 def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure]> {
-  let arguments = (ins Variadic<AnyVector>:$sources);
+  let arguments = (ins
+    Variadic<AnyVector>:$sources,
+    TPU_PackFormatEnum:$pack_format
+  );
   let results = (outs AnyVector:$output);
   let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }];
 }
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index 862c72ff52b0..d90733d7791c 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -827,7 +827,8 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
         ++idxs_local.back();
       }
     }
-    *v = builder.create<tpu::PackSubelementsOp>(res_vreg_ty, parts);
+    *v = builder.create<tpu::PackSubelementsOp>(res_vreg_ty, parts,
+                                                tpu::PackFormat::kCompressed);
   });
 } else if (layout_out.hasNativeTiling(ctx.target_shape)) {
   int packing = layout_out.packing();
@@ -848,7 +849,8 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
           parts.push_back(parts.back());
         }
       }
-      *v = builder.create<tpu::PackSubelementsOp>(res_vreg_ty, parts);
+      *v = builder.create<tpu::PackSubelementsOp>(res_vreg_ty, parts,
+                                                  tpu::PackFormat::kCompressed);
       parts.clear();
     });
   } else {
@@ -4576,10 +4578,8 @@ FailureOr<xla::Array<Value>> relayout(
   } else if (  // TODO(b/265133506): Generalize retiling.
       // (8,128) -> (8 * packing,128) tiling change for packed type.
       src.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
-      dst.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
-      vty.getElementTypeBitWidth() < 32 &&
-      32 % vty.getElementTypeBitWidth() == 0 &&
-      src.offsets() == dst.offsets() &&
+      dst.implicit_dim() == VectorLayout::ImplicitDim::kNone && bitwidth < 32 &&
+      32 % bitwidth == 0 && src.offsets() == dst.offsets() &&
       src.tiling() == std::array<int64_t, 2>{8, 128} &&
       dst.tiling() == std::array<int64_t, 2>{8 * dst.packing(), 128}) {
     const VectorLayout new_src(src.bitwidth(), src.offsets(), dst.tiling());
@@ -4606,7 +4606,93 @@ FailureOr<xla::Array<Value>> relayout(
         }
       }
       *tile = builder.create<tpu::PackSubelementsOp>(
-          v.getLoc(), src_tiles.begin()->getType(), parts);
+          v.getLoc(), src_tiles.begin()->getType(), parts,
+          tpu::PackFormat::kCompressed);
     });
     src = new_src;
     src_tiles = std::move(src_tiles_retiled);
+  } else if (  // Handle retiling from (1, 128 * packing) to (packing, 128) for
+               // packed data.
+               // We do compressed unpacking followed by interleaved packing.
+               // TODO(tlongeri): This can be used as a first step before using
+               // a generalized retiling where we only move sublanes around
+               // (without packing/unpacking).
+               // TODO(tlongeri): Interleaved unpacking followed by interleaved
+               // packing (but with different pairings) might also be
+               // interesting if the next step is a retile, since we can also
+               // match corresponding elements without shifting. It's just that
+               // the tiles are not adjacent (no contiguous vreg slice).
+      src.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
+      dst.implicit_dim() == VectorLayout::ImplicitDim::kNone && bitwidth < 32 &&
+      32 % bitwidth == 0 && src.offsets() == dst.offsets() &&
+      src.tiling() == std::array<int64_t, 2>{1, 128 * packing} &&
+      dst.tiling() == std::array<int64_t, 2>{packing, 128}) {
+    // To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of
+    // 4 sublanes and 2 lanes (this is convenient to keep the example small
+    // yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling.
+    //
+    // The vreg slice is 1 x 16, that is, the vreg contains the data for a
+    // 1 x 16 window of the logical shape.
+    //
+    // [a b c d e f g h i j k l m n o p] -> vreg 1
+    // [A B C D E F G H I J K L M N O P] -> vreg 2
+    //
+    // Note: we support multiple vregs per row of the logical shape, but we use
+    // one here just to keep the example small.
+    //
+    // When we do a compressed unpack, the resulting vregs effectively have a
+    // tiling of (1, 2) and cover a vreg slice of 1 x 8 logical elements.
+    //
+    // [a b c d e f g h] -> vreg 1, part 1   [i j k l m n o p] -> vreg 1, part 2
+    // [A B C D E F G H] -> vreg 2, part 1   [I J K L M N O P] -> vreg 2, part 2
+    //
+    // It is clear that if we combine vreg 1, part 1 and vreg 2, part 1 we get
+    // data that covers a 2 x 8 vreg slice. Note, however, that we will have to
+    // mind the internal ordering of the vreg.
+    //
+    // [a b c d e f g h                   [i j k l m n o p
+    //  A B C D E F G H] -> new vreg 1     I J K L M N O P] -> new vreg 2
+    //
+    // To see if we can get the right internal ordering that we need for (2, 2)
+    // tiling, let's break new vreg 1 into (1, 2) rows, which correspond to
+    // sublanes when unpacked and half-sublanes when packed.
+    //
+    // [(a b) (c d) (e f) (g h)
+    //  (A B) (C D) (E F) (G H)]
+    //
+    // The sublane order for the vreg parts is [(a b) (c d) ...] for vreg 1,
+    // part 1 and [(A B) (C D) ...] for vreg 2, part 1.
+    //
+    // The desired half-sublane order, for packed (2, 2) tiling, is
+    // [(a b) (A B) (c d) (C D) ...]. That is, traverse down each column before
+    // moving to the next one. This is exactly an interleaving of the sublanes
+    // of the vreg parts.
+    const VectorLayout new_src(src.bitwidth(), src.offsets(),
+                               std::array<int64_t, 2>{packing, 128});
+    xla::Array<Value> src_tiles_retiled(
+        new_src.tileArrayShape(vty.getShape(), target_shape));
+    const VectorType vreg_x32 =
+        vty.getElementType().isSignlessInteger()
+            ? VectorType::get(target_shape, builder.getI32Type())
+            : VectorType::get(target_shape, builder.getF32Type());
+    src_tiles_retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
+      SmallVector<Value> parts;
+      parts.reserve(packing);
+      SmallVector<int64_t> src_idx(toArrayRef(idx));
+      *(src_idx.end() - 2) *= packing;
+      const int64_t vreg_part = *(src_idx.end() - 1) % packing;
+      *(src_idx.end() - 1) /= packing;
+      for (int i = 0; i < packing; ++i) {
+        parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
+            v.getLoc(), vreg_x32, src_tiles(src_idx), vreg_part));
+        if (*(src_idx.end() - 2) < *(src_tiles.dimensions().end() - 2)) {
+          ++*(src_idx.end() - 2);
+        }  // The rest is padding, so just pick any of the input parts (but not
+           // an arbitrary vreg so we don't add an extra dependency).
+      }
+      *tile = builder.create<tpu::PackSubelementsOp>(
+          v.getLoc(), src_tiles.begin()->getType(), parts,
+          tpu::PackFormat::kInterleaved);
+    });
+    src = new_src;
+    src_tiles = std::move(src_tiles_retiled);
   }

From 43f51d73ceb154a281452a0f7189a7118de98fe5 Mon Sep 17 00:00:00 2001
From: Michael Levesque-Dion
Date: Tue, 28 May 2024 10:58:10 -0700
Subject: [PATCH 08/18] Clean up version switches from dense array migration

PiperOrigin-RevId: 637955865
---
 jax/_src/interpreters/mlir.py       | 30 +++++++----------------
 jax/_src/lax/convolution.py         | 12 +++++-----
 jax/_src/lax/lax.py                 |  6 ++---
 jax/_src/lax/parallel.py            |  2 +-
 jax/_src/lax/slicing.py             |  2 +-
 jax/_src/lax/windowed_reductions.py |  4 ++--
 jax/experimental/export/_export.py  | 37 +++++++++++++----------------
 jax/interpreters/mlir.py            |  1 -
 jaxlib/gpu_solver.py                |  5 ++--
 jaxlib/hlo_helpers.py               | 11 +--------
 10 files changed, 42 insertions(+), 68 deletions(-)

diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 36d3c884b8be..9f333ea3897e 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -90,16 +90,7 @@ def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
   return type_cast(ir.DenseIntElementsAttr,
                    ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)))
 
-def dense_int_array(xs) -> ir.DenseElementsAttr | ir.DenseI64ArrayAttr:
-  # TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v5 or higher
-  if hlo.get_api_version() < 5:
-    return dense_int_elements(xs)
-  return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))  # type: ignore
-
-# TODO: b/321794305 - delete this when jaxlib is on StableHLO API v6 or higher
-def dense_int_array_v6(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
-  if hlo.get_api_version() < 6:
-    return dense_int_elements(xs)
+def dense_int_array(xs) -> ir.DenseI64ArrayAttr:
   return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))  # type: ignore
 
 def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
@@ -111,10 +102,7 @@ def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
   return ir.DenseElementsAttr.get(
       a, type=ir.IntegerType.get_signless(1), shape=[len(xs)])
 
-def dense_bool_array(xs: Sequence[bool]) -> ir.DenseElementsAttr | ir.DenseBoolArrayAttr:
-  # TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v6 or higher
-  if hlo.get_api_version() < 6:
-    return dense_bool_elements(xs)
+def dense_bool_array(xs: Sequence[bool]) -> ir.DenseBoolArrayAttr:
   return ir.DenseBoolArrayAttr.get(xs)  # type: ignore
 
 def i32_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i)
@@ -321,7 +309,7 @@ def _ndarray_constant_handler(val: np.ndarray | np.generic) -> Sequence[ir.Value
           ir.RankedTensorType.get(
               val.shape, dtype_to_ir_type(collapsed_val.dtype)),  # type: ignore
           _numpy_array_constant(collapsed_val)[0],
-          dense_int_array_v6(other_axes))
+          dense_int_array(other_axes))
       return (out,)
     else:
       return _numpy_array_constant(val)
@@ -1885,14 +1873,14 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue,
       return hlo.dynamic_broadcast_in_dim(
           aval_to_ir_type(aval_out), op,
           shape,
-          dense_int_array_v6(broadcast_dimensions),
+          dense_int_array(broadcast_dimensions),
       )
   else:
     assert all(d != ir.ShapedType.get_dynamic_size() for d in aval_out.shape), aval_out  # type: ignore
     return hlo.broadcast_in_dim(
         aval_to_ir_type(aval_out), op,
-        dense_int_array_v6(broadcast_dimensions))
+        dense_int_array(broadcast_dimensions))
 
 def multi_broadcast_in_dim(ctx: LoweringRuleContext,
                            ops: Sequence[ir.Value],
@@ -2725,10 +2713,10 @@ def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]):
   rw = hlo.ReduceWindowOp(
       list(map(aval_to_ir_type, out_avals)),
       operands, init_values,
-      dense_int_array_v6(window_dimensions),
-      window_strides=dense_int_array_v6(window_strides),
-      base_dilations=dense_int_array_v6(base_dilation),
-      window_dilations=dense_int_array_v6(window_dilation),
+      dense_int_array(window_dimensions),
+      window_strides=dense_int_array(window_strides),
+      base_dilations=dense_int_array(base_dilation),
+      window_dilations=dense_int_array(window_dilation),
       padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                           shape=[len(padding), 2]))
   reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py
index f241188b8f19..2b2ad5bbb515 100644
--- a/jax/_src/lax/convolution.py
+++ b/jax/_src/lax/convolution.py
@@ -719,10 +719,10 @@ def _conv_general_dilated_lower(
         dimension_numbers=dnums,
         feature_group_count=mlir.i64_attr(feature_group_count),
         batch_group_count=mlir.i64_attr(batch_group_count),
-        window_strides=mlir.dense_int_array_v6(window_strides),
+        window_strides=mlir.dense_int_array(window_strides),
        padding=mlir.dense_int_elements(padding),
-        lhs_dilation=mlir.dense_int_array_v6(lhs_dilation),
-        rhs_dilation=mlir.dense_int_array_v6(rhs_dilation),
+        lhs_dilation=mlir.dense_int_array(lhs_dilation),
+        rhs_dilation=mlir.dense_int_array(rhs_dilation),
         window_reversal=window_reversal,
         precision_config=lax.precision_attr(precision))
     ]
@@ -744,9 +744,9 @@ def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]):
         dimension_numbers=dnums,
         feature_group_count=mlir.i64_attr(feature_group_count),
         batch_group_count=mlir.i64_attr(batch_group_count),
-        window_strides=mlir.dense_int_array_v6(window_strides),
-        lhs_dilation=mlir.dense_int_array_v6(lhs_dilation),
-        rhs_dilation=mlir.dense_int_array_v6(rhs_dilation),
+        window_strides=mlir.dense_int_array(window_strides),
+        lhs_dilation=mlir.dense_int_array(lhs_dilation),
+        rhs_dilation=mlir.dense_int_array(rhs_dilation),
         window_reversal=window_reversal,
         precision_config=lax.precision_attr(precision))
     ]
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 580da0b52a2b..1f621d685a15 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -1760,7 +1760,7 @@ def broadcast_hlo(
   for aval, arg in zip(avals, args):
     if aval.shape != aval_out.shape:
       assert len(aval.shape) <= len(aval_out.shape), (aval, aval_out)
-      dims = mlir.dense_int_array_v6(
+      dims = mlir.dense_int_array(
           range(len(aval_out.shape) - len(aval.shape), len(aval_out.shape)))
       if any(isinstance(d, ir.Value) for d in aval_out.shape):
         arg = hlo.dynamic_broadcast_in_dim(
@@ -3963,7 +3963,7 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions):
   operands, init_values =
      util.split_list(values, [len(values) // 2])
   init_value_avals = ctx.avals_in[len(values) // 2:]
   op = hlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
-                    operands, init_values, mlir.dense_int_array_v6(dimensions))
+                    operands, init_values, mlir.dense_int_array(dimensions))
   ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
   reducer = op.regions[0].blocks.append(*(ir_types + ir_types))
   with ir.InsertionPoint(reducer):
@@ -4174,7 +4174,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
   dtype = aval_out.dtype
   op = hlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x],
                     mlir.ir_constants(unit_factory(aval_out.dtype)),
-                    mlir.dense_int_array_v6(axes))
+                    mlir.dense_int_array(axes))
   scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), dtype))
   reducer_region = op.regions[0].blocks.append(scalar_type, scalar_type)
   with ir.InsertionPoint(reducer_region):
diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py
index 0d1ed7783208..cd1dbd313d1a 100644
--- a/jax/_src/lax/parallel.py
+++ b/jax/_src/lax/parallel.py
@@ -1271,7 +1271,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
       broadcast_dimensions = [i for i in range(len(new_shape)) if i != all_gather_dimension]
       x = hlo.broadcast_in_dim(
           mlir.aval_to_ir_type(x_aval.update(shape=new_shape)), x,
-          mlir.dense_int_array_v6(broadcast_dimensions))
+          mlir.dense_int_array(broadcast_dimensions))
     replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
                                      axis_index_groups)
     if is_spmd:
diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py
index fe8027e5111d..e58d5a7c7909 100644
--- a/jax/_src/lax/slicing.py
+++ b/jax/_src/lax/slicing.py
@@ -1845,7 +1845,7 @@ def _gather_lower(ctx, operand, indices, *,
       operand,
       indices,
       dnums,
-      mlir.dense_int_array_v6(slice_sizes),
+      mlir.dense_int_array(slice_sizes),
       indices_are_sorted=ir.BoolAttr.get(indices_are_sorted))]
 
 mlir.register_lowering(gather_p, _gather_lower)
diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py
index 18db9c764903..8a3fbf2c37bb 100644
--- a/jax/_src/lax/windowed_reductions.py
+++ b/jax/_src/lax/windowed_reductions.py
@@ -665,8 +665,8 @@ def _select_and_scatter_lower(
       operand,
       source,
      init_value,
-      window_dimensions=mlir.dense_int_array_v6(window_dimensions),
-      window_strides=mlir.dense_int_array_v6(window_strides),
+      window_dimensions=mlir.dense_int_array(window_dimensions),
+      window_strides=mlir.dense_int_array(window_strides),
       padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
                                           shape=(len(padding), 2)))
   select = op.select.blocks.append(scalar_type, scalar_type)
diff --git a/jax/experimental/export/_export.py b/jax/experimental/export/_export.py
index c90a80b1cd16..1689bdc87608 100644
--- a/jax/experimental/export/_export.py
+++ b/jax/experimental/export/_export.py
@@ -514,26 +514,23 @@ def export_sharding(s: LoweringSharding,
 
 def _module_to_bytecode(module: ir.Module) -> bytes:
   mlir_str = mlir.module_to_bytecode(module)
-  if hlo.get_api_version() < 4:
-    target_version = hlo.get_earliest_forward_compatible_version()
-  else:
-    # `target_version` is used to manage situations when a StableHLO producer
-    # (in this case, jax2tf) and a StableHLO consumer were built using
-    # different versions of StableHLO.
-    #
-    # Each StableHLO version `producer_version` has a compatibility window,
-    # i.e. range of versions [`consumer_version_min`, `consumer_version_max`],
-    # where StableHLO portable artifacts serialized by `producer_version`
-    # can be deserialized by `consumer_version` within the window.
-    # See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
-    # for the exact extent of these compatibility guarantees.
-    #
-    # `hlo.get_minimum_version()` returns `consumer_version_min`
-    # for the current version of StableHLO. We are using it here to maximize
-    # forward compatibility, i.e. to maximize how far into the past we can go
-    # and still have the payloads produced by `serialize_portable_artifact`
-    # compatible with potential consumers from the past.
-    target_version = hlo.get_minimum_version()
+  # `target_version` is used to manage situations when a StableHLO producer
+  # (in this case, jax2tf) and a StableHLO consumer were built using
+  # different versions of StableHLO.
+  #
+  # Each StableHLO version `producer_version` has a compatibility window,
+  # i.e. range of versions [`consumer_version_min`, `consumer_version_max`],
+  # where StableHLO portable artifacts serialized by `producer_version`
+  # can be deserialized by `consumer_version` within the window.
+  # See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
+  # for the exact extent of these compatibility guarantees.
+  #
+  # `hlo.get_minimum_version()` returns `consumer_version_min`
+  # for the current version of StableHLO. We are using it here to maximize
+  # forward compatibility, i.e. to maximize how far into the past we can go
+  # and still have the payloads produced by `serialize_portable_artifact`
+  # compatible with potential consumers from the past.
+  target_version = hlo.get_minimum_version()
   module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
       mlir_str, target_version)
   return module_serialized
diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py
index 7a72a478b807..ba476c75e519 100644
--- a/jax/interpreters/mlir.py
+++ b/jax/interpreters/mlir.py
@@ -37,7 +37,6 @@
     dense_bool_elements as dense_bool_elements,
     dense_bool_array as dense_bool_array,
     dense_int_array as dense_int_array,
-    dense_int_array_v6 as dense_int_array_v6,
     dense_int_elements as dense_int_elements,
     dtype_to_ir_type as dtype_to_ir_type,
     emit_python_callback as emit_python_callback,
diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py
index 4f316594765c..e804068e5e6d 100644
--- a/jaxlib/gpu_solver.py
+++ b/jaxlib/gpu_solver.py
@@ -28,7 +28,7 @@
 
 from .hlo_helpers import (
     DimensionSize, ShapeTypePair, mk_result_types_and_shapes,
-    custom_call, ensure_hlo_s32, hlo_s32, dense_int_array, dense_int_array_v6)
+    custom_call, ensure_hlo_s32, hlo_s32, dense_int_array)
 
 try:
   from .cuda import _blas as _cublas  # pytype: disable=import-error
@@ -536,14 +536,13 @@ def _sytrd_hlo(platform, gpu_solver, dtype, a, *, lower):
   # simply copy it back to where it needs to be:
   intattr = lambda xs: ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
   intarrattr = lambda xs: dense_int_array(np.asarray(xs, np.int64))
-  intarrattr_v6 = lambda xs: dense_int_array_v6(np.asarray(xs, np.int64))
   if not lower and platform == "cu" and m > 1:
     start = (0,) * len(batch_dims) + (0,)
     end = batch_dims + (1,)
     s = hlo.slice(
         e, intarrattr(start), intarrattr(end), intarrattr([1] * len(start)))
     s_type = ir.RankedTensorType.get(batch_dims + (1, 1), diag_type)
-    s = hlo.broadcast_in_dim(s_type, s, intarrattr_v6(range(len(dims) - 1)))
+    s = hlo.broadcast_in_dim(s_type, s, intarrattr(range(len(dims) - 1)))
    # The diagonals are always real; convert to complex if needed.
    s = hlo.convert(
        ir.RankedTensorType.get(s_type.shape, a_type.element_type), s)
diff --git a/jaxlib/hlo_helpers.py b/jaxlib/hlo_helpers.py
index ee59a1b96ee1..4ec995172585 100644
--- a/jaxlib/hlo_helpers.py
+++ b/jaxlib/hlo_helpers.py
@@ -110,16 +110,7 @@ def hlo_s32(x: int):
 def ensure_hlo_s32(x: DimensionSize):
   return hlo_s32(x) if isinstance(x, int) else x
 
-def dense_int_array(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
-  # TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v5 or higher
-  if hlo.get_api_version() < 5:
-    return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
-  return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
-
-# TODO: b/321794305 - delete this when jaxlib is on StableHLO API v6 or higher
-def dense_int_array_v6(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
-  if hlo.get_api_version() < 6:
-    return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
+def dense_int_array(xs) -> ir.DenseI64ArrayAttr:
   return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
 
 def hlo_min(x: DimensionSize, y: DimensionSize) -> DimensionSize:

From 437263659cba0cd12e2d94c62f7395c156b5f6dd Mon Sep 17 00:00:00 2001
From: Henning Becker
Date: Tue, 28 May 2024 11:06:06 -0700
Subject: [PATCH 09/18] Fix pip dependency after cuDNN 9 upgrade

PiperOrigin-RevId: 637959308
---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index f6861972c13a..e1cb3e2e38b0 100644
--- a/setup.py
+++ b/setup.py
@@ -22,7 +22,7 @@
 _current_jaxlib_version = '0.4.28'
 # The following should be updated with each new jaxlib release.
 _latest_jaxlib_version_on_pypi = '0.4.28'
-_default_cuda12_cudnn_version = '89'
+_default_cuda12_cudnn_version = '91'
 _available_cuda12_cudnn_versions = [_default_cuda12_cudnn_version]
 
 _libtpu_version = '0.1.dev20240508'

From 9e80c30724d5613557334bb7e6366f765a73f1f1 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Tue, 28 May 2024 11:36:53 -0700
Subject: [PATCH 10/18] Improve documentation for jnp.take & jnp.take_along_axis

---
 jax/_src/numpy/lax_numpy.py | 175 +++++++++++++++++++++++++++++-------
 1 file changed, 143 insertions(+), 32 deletions(-)

diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index af4a1e4451a8..d7b7cc2b7fe6 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -5608,27 +5608,6 @@ def unpackbits(
   return swapaxes(unpacked, axis, -1)
 
 
-@util.implements(np.take, skip_params=['out'],
-                 lax_description="""
-By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound
-index semantics can be specified via the ``mode`` parameter (see below).
-""",
-                 extra_params="""
-mode : string, default="fill"
-    Out-of-bounds indexing mode. The default mode="fill" returns invalid values
-    (e.g. NaN) for out-of bounds indices (see also ``fill_value`` below).
-    For more discussion of mode options, see :attr:`jax.numpy.ndarray.at`.
-fill_value : optional
-    The fill value to return for out-of-bounds slices when mode is 'fill'. Ignored
-    otherwise. Defaults to NaN for inexact types, the largest negative value for
-    signed types, the largest positive value for unsigned types, and True for booleans.
-unique_indices : bool, default=False
-    If True, the implementation will assume that the indices are unique,
-    which can result in more efficient execution on some backends.
-indices_are_sorted : bool, default=False
-    If True, the implementation will assume that the indices are sorted in
-    ascending order, which can lead to more efficient execution on some backends.
-""")
 def take(
     a: ArrayLike,
     indices: ArrayLike,
     axis: int | None = None,
     out: None = None,
     mode: str | None = None,
     *,
     unique_indices: bool = False,
     indices_are_sorted: bool = False,
     fill_value: StaticScalar | None = None,
 ) -> Array:
+  """Take elements from an array.
+
+  JAX implementation of :func:`numpy.take`, implemented in terms of
+  :func:`jax.lax.gather`. JAX's behavior differs from NumPy in the case
+  of out-of-bound indices; see the ``mode`` parameter below.
+
+  Args:
+    a: array from which to take values.
+    indices: N-dimensional array of integer indices of values to take from the array.
+    axis: the axis along which to take values. If not specified, the array will
+      be flattened before indexing is applied.
+    mode: Out-of-bounds indexing mode, either ``"fill"`` or ``"clip"``. The default
+      ``mode="fill"`` returns invalid values (e.g. NaN) for out-of-bounds indices;
+      the ``fill_value`` argument gives control over this value. For more discussion
+      of ``mode`` options, see :attr:`jax.numpy.ndarray.at`.
+    fill_value: The fill value to return for out-of-bounds slices when mode is 'fill'.
+      Ignored otherwise. Defaults to NaN for inexact types, the largest negative value
+      for signed types, the largest positive value for unsigned types, and True for
+      booleans.
+    unique_indices: If True, the implementation will assume that the indices are unique,
+      which can result in more efficient execution on some backends. If set to True and
+      indices are not unique, the output is undefined.
+    indices_are_sorted: If True, the implementation will assume that the indices are
+      sorted in ascending order, which can lead to more efficient execution on some
+      backends. If set to True and indices are not sorted, the output is undefined.
+
+  Returns:
+    Array of values extracted from ``a``.
+
+  See also:
+    - :attr:`jax.numpy.ndarray.at`: take values via indexing syntax.
+    - :func:`jax.numpy.take_along_axis`: take values along an axis.
+
+  Example:
+    >>> x = jnp.array([[1., 2., 3.],
+    ...                [4., 5., 6.]])
+    >>> indices = jnp.array([2, 0])
+
+    Passing no axis results in indexing into the flattened array:
+
+    >>> jnp.take(x, indices)
+    Array([3., 1.], dtype=float32)
+    >>> x.ravel()[indices]  # equivalent indexing syntax
+    Array([3., 1.], dtype=float32)
+
+    Passing an axis results in applying the index to every subarray along the axis:
+
+    >>> jnp.take(x, indices, axis=1)
+    Array([[3., 1.],
+           [6., 4.]], dtype=float32)
+    >>> x[:, indices]  # equivalent indexing syntax
+    Array([[3., 1.],
+           [6., 4.]], dtype=float32)
+
+    Out-of-bound indices fill with invalid values. For float inputs, this is `NaN`:
+
+    >>> jnp.take(x, indices, axis=0)
+    Array([[nan, nan, nan],
+           [ 1.,  2.,  3.]], dtype=float32)
+    >>> x.at[indices].get(mode='fill', fill_value=jnp.nan)  # equivalent indexing syntax
+    Array([[nan, nan, nan],
+           [ 1.,  2.,  3.]], dtype=float32)
+
+    This default out-of-bound behavior can be adjusted using the ``mode`` parameter;
+    for example, we can instead clip to the last valid value:
+
+    >>> jnp.take(x, indices, axis=0, mode='clip')
+    Array([[4., 5., 6.],
+           [1., 2., 3.]], dtype=float32)
+    >>> x.at[indices].get(mode='clip')  # equivalent indexing syntax
+    Array([[4., 5., 6.],
+           [1., 2., 3.]], dtype=float32)
+  """
   return _take(a, indices, None if axis is None else operator.index(axis), out,
                mode, unique_indices=unique_indices,
                indices_are_sorted=indices_are_sorted, fill_value=fill_value)
@@ -5714,17 +5765,6 @@ def _normalize_index(index, axis_size):
   return lax.select(index < 0, lax.add(index, axis_size_val), index)
 
 
-TAKE_ALONG_AXIS_DOC = """
-Unlike :func:`numpy.take_along_axis`, :func:`jax.numpy.take_along_axis` takes
-an optional ``mode`` parameter controlling how out-of-bounds indices should be
-handled. By default, out-of-bounds indices yield invalid values (e.g., ``NaN``).
-See :attr:`jax.numpy.ndarray.at` for further discussion of out-of-bounds
-indexing in JAX.
-"""
-
-
-@util.implements(np.take_along_axis, update_doc=False,
-                 lax_description=TAKE_ALONG_AXIS_DOC)
 @partial(jit, static_argnames=('axis', 'mode', 'fill_value'))
 def take_along_axis(
     arr: ArrayLike,
     indices: ArrayLike,
     axis: int | None,
     mode: str | lax.GatherScatterMode | None = None,
     fill_value: StaticScalar | None = None,
 ) -> Array:
+  """Take elements from an array.
+
+  JAX implementation of :func:`numpy.take_along_axis`, implemented in
+  terms of :func:`jax.lax.gather`. JAX's behavior differs from NumPy
+  in the case of out-of-bound indices; see the ``mode`` parameter below.
+
+  Args:
+    arr: array from which to take values.
+    indices: array of integer indices. If ``axis`` is ``None``, must be one-dimensional.
+      If ``axis`` is not None, must have ``arr.ndim == indices.ndim``, and ``arr`` must
+      be broadcast-compatible with ``indices`` along dimensions other than ``axis``.
+    axis: the axis along which to take values. If not specified, the array will
+      be flattened before indexing is applied.
+    mode: Out-of-bounds indexing mode, either ``"fill"`` or ``"clip"``. The default
+      ``mode="fill"`` returns invalid values (e.g. NaN) for out-of-bounds indices.
+      For more discussion of ``mode`` options, see :attr:`jax.numpy.ndarray.at`.
+    fill_value: The fill value to return for out-of-bounds slices when mode is
+      'fill'. Ignored otherwise.
+
+  Returns:
+    Array of values extracted from ``arr``.
+
+  See also:
+    - :attr:`jax.numpy.ndarray.at`: take values via indexing syntax.
+    - :func:`jax.numpy.take`: take the same indices along every axis slice.
+
+  Examples:
+    >>> x = jnp.array([[1., 2., 3.],
+    ...                [4., 5., 6.]])
+    >>> indices = jnp.array([[0, 2],
+    ...                      [1, 0]])
+    >>> jnp.take_along_axis(x, indices, axis=1)
+    Array([[1., 3.],
+           [5., 4.]], dtype=float32)
+    >>> x[jnp.arange(2)[:, None], indices]  # equivalent via indexing syntax
+    Array([[1., 3.],
+           [5., 4.]], dtype=float32)
+
+    Out-of-bound indices fill with invalid values. For float inputs, this is `NaN`:
+
+    >>> indices = jnp.array([[1, 0, 2]])
+    >>> jnp.take_along_axis(x, indices, axis=0)
+    Array([[ 4.,  2., nan]], dtype=float32)
+    >>> x.at[indices, jnp.arange(3)].get(
+    ...     mode='fill', fill_value=jnp.nan)  # equivalent via indexing syntax
+    Array([[ 4.,  2., nan]], dtype=float32)
+
+    ``take_along_axis`` is helpful for extracting values from multi-dimensional
+    argsorts and arg reductions. For example, here we compute
+    :func:`~jax.numpy.argsort` indices along an axis, and use ``take_along_axis``
+    to construct the sorted array:
+
+    >>> x = jnp.array([[5, 3, 4],
+    ...                [2, 7, 6]])
+    >>> indices = jnp.argsort(x, axis=1)
+    >>> indices
+    Array([[1, 2, 0],
+           [0, 2, 1]], dtype=int32)
+    >>> jnp.take_along_axis(x, indices, axis=1)
+    Array([[3, 4, 5],
+           [2, 6, 7]], dtype=int32)
+
+    Similarly, we can use :func:`~jax.numpy.argmin` with ``keepdims=True`` and
+    use ``take_along_axis`` to extract the minimum value:
+
+    >>> idx = jnp.argmin(x, axis=1, keepdims=True)
+    >>> idx
+    Array([[1],
+           [0]], dtype=int32)
+    >>> jnp.take_along_axis(x, idx, axis=1)
+    Array([[3],
+           [2]], dtype=int32)
+  """
   util.check_arraylike("take_along_axis", arr, indices)
   a = asarray(arr)
   index_dtype = dtypes.dtype(indices)

From b441a09a34d18040d5119c8eb29666fc35a646b4 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Tue, 28 May 2024 13:13:40 -0700
Subject: [PATCH 11/18] CI: remove stale warning filters

---
 pyproject.toml | 33 +++++++--------------------------
 1 file changed, 7 insertions(+), 26 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 752cd9bc08f4..c06c16029560 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -58,35 +58,16 @@ markers = [
 ]
 filterwarnings = [
     "error",
-    "ignore:The hookimpl.*:DeprecationWarning",
-    "ignore:No GPU/TPU found, falling back to CPU.:UserWarning",
-    "ignore:xmap is an experimental feature and probably has bugs!",
-    "ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning",
-    "ignore:can't resolve package from __spec__ or __package__:ImportWarning",
-    "ignore:Using or importing the ABCs.*:DeprecationWarning",
-    "ignore:numpy.ufunc size changed",
-    "ignore:.*experimental feature",
-    "ignore:The distutils.* is deprecated.*:DeprecationWarning",
-    "default:Error reading persistent compilation cache entry for 'jit_equal'",
-    "default:Error reading persistent compilation cache entry for 'jit__lambda_'",
-    "default:Error writing persistent compilation cache entry for 'jit_equal'",
-    "default:Error writing persistent compilation cache entry for 'jit__lambda_'",
-    "ignore:backend and device argument on jit is deprecated.*:DeprecationWarning",
-    # TODO(skyewm): remove when jaxlib >= 0.4.12 is released (needs
-    # https://github.com/openxla/xla/commit/fb9dc3db0999bf14c78d95cb7c3aa6815221ddc7)
-    "ignore:ml_dtypes.float8_e4m3b11 is deprecated.",
-    "ignore:JAX_USE_PJRT_C_API_ON_TPU=false will no longer be supported.*:UserWarning",
-    "ignore:np.find_common_type is deprecated.*:DeprecationWarning",
-    "ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning",
+    "default:Error (reading|writing) persistent compilation cache entry for 'jit_equal'",
+    "default:Error (reading|writing) persistent compilation cache entry for 'jit__lambda_'",
+    "default:backend and device argument on jit is deprecated.*:DeprecationWarning",
+    "default:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
     # TODO(jakevdp): remove when array_api_tests stabilize
     # start array_api_tests-related warnings
-    "ignore:The numpy.array_api submodule is still experimental.*:UserWarning",
-    "ignore:case not machine-readable.*:UserWarning",
-    "ignore:not machine-readable.*:UserWarning",
-    "ignore:Special cases found for .* but none were parsed.*:UserWarning",
+    "default:.*not machine-readable.*:UserWarning",
+    "default:Special cases found for .* but none were parsed.*:UserWarning",
+    "default:.*is not JSON-serializable. Using the repr instead.",
     # end array_api_tests-related warnings
-    "ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
-    "ignore:.*is not JSON-serializable. Using the repr instead.",
 ]
 doctest_optionflags = [
     "NUMBER",

From fcdbb2a29269dab83a2eb5f5e7662ab6dfc2e930 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Tue, 28 May 2024 13:24:36 -0700
Subject: [PATCH 12/18] Improve documentation for jnp.bincount

---
 jax/_src/numpy/lax_numpy.py | 68 ++++++++++++++++++++++++++++++++-----
 1 file changed, 59 insertions(+), 9 deletions(-)

diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index d7b7cc2b7fe6..137db6530f77 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -1620,18 +1620,68 @@ def select(
   return lax.select_n(*broadcast_arrays(idx, *choicelist))
 
 
-@util.implements(np.bincount, lax_description="""\
-Jax adds the optional `length` parameter which specifies the output length, and
-defaults to ``x.max() + 1``. It must be specified for bincount to be compiled
-with non-static operands. Values larger than the specified length will be discarded.
-If `length` is specified, `minlength` will be ignored.
-
-Additionally, while ``np.bincount`` raises an error if the input array contains
-negative values, ``jax.numpy.bincount`` clips negative values to zero.
-""")
 def bincount(x: ArrayLike, weights: ArrayLike | None = None,
              minlength: int = 0, *, length: int | None = None
             ) -> Array:
+  """Count the number of occurrences of each value in an integer array.
+
+  JAX implementation of :func:`numpy.bincount`.
+
+  For an array of non-negative integers ``x``, this function returns an array
+  ``counts`` of size ``x.max() + 1``, such that ``counts[i]`` contains the
+  number of occurrences of the value ``i`` in ``x``.
+
+  The JAX version has a few differences from the NumPy version:
+
+  - In NumPy, passing an array ``x`` with negative entries will result in an
+    error. In JAX, negative values are clipped to zero.
+  - JAX adds an optional ``length`` parameter which can be used to statically
+    specify the length of the output array so that this function can be used
+    with transformations like :func:`jax.jit`. In this case, items larger than
+    ``length - 1`` will be dropped.
+
+  Args:
+    x: N-dimensional array of non-negative integers.
+    weights: optional array of weights associated with ``x``. If not specified,
+      the weight for each entry will be ``1``.
+    minlength: the minimum length of the output counts array.
+    length: the length of the output counts array. Must be specified statically
+      for ``bincount`` to be used with :func:`jax.jit` and other JAX
+      transformations.
+
+  Returns:
+    An array of counts or summed weights reflecting the number of occurrences
+    of values in ``x``.
+
+  See Also:
+    - :func:`jax.numpy.histogram`
+    - :func:`jax.numpy.digitize`
+    - :func:`jax.numpy.unique_counts`
+
+  Examples:
+    Basic bincount:
+
+    >>> x = jnp.array([1, 1, 2, 3, 3, 3])
+    >>> jnp.bincount(x)
+    Array([0, 2, 1, 3], dtype=int32)
+
+    Weighted bincount:
+
+    >>> weights = jnp.array([1, 2, 3, 4, 5, 6])
+    >>> jnp.bincount(x, weights)
+    Array([ 0,  3,  3, 15], dtype=int32)
+
+    Specifying a static ``length`` makes this jit-compatible:
+
+    >>> jit_bincount = jax.jit(jnp.bincount, static_argnames=['length'])
+    >>> jit_bincount(x, length=5)
+    Array([0, 2, 1, 3, 0], dtype=int32)
+
+    Any negative numbers are clipped to the first bin, and numbers beyond the
+    specified ``length`` are dropped:
+
+    >>> x = jnp.array([-1, -1, 1, 3, 10])
+    >>> jnp.bincount(x, length=5)
+    Array([2, 1, 0, 1, 0], dtype=int32)
+  """
   util.check_arraylike("bincount", x)
   if not issubdtype(_dtype(x), integer):
     raise TypeError(f"x argument to bincount must have an integer type; got {_dtype(x)}")

From 91d68b55646eb2693d1244024e70cf1787bcb93e Mon Sep 17 00:00:00 2001
From: Yazhou Zu
Date: Tue, 28 May 2024 13:42:18 -0700
Subject: [PATCH 13/18] Create a JAX config API to allow custom PJRT client
 create option settings

This allows a device platform's PJRT client to be aware of the calling
(customer) ML framework.
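
A sketch of the intended usage (hypothetical values; on Cloud TPU the string
is filled in automatically by ``cloud_tpu_init``):

```
import jax

jax.config.update(
    'jax_pjrt_client_create_options',
    'ml_framework_name:JAX;ml_framework_version:0.4.28')
```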
PiperOrigin-RevId: 638009713
---
 jax/BUILD                  |  1 +
 jax/_src/cloud_tpu_init.py |  9 ++++++++-
 jax/_src/config.py         |  6 ++++++
 jax/_src/xla_bridge.py     | 41 ++++++++++++++++++++++++++++----------
 tests/xla_bridge_test.py   | 17 ++++++++++++----
 5 files changed, 59 insertions(+), 15 deletions(-)

diff --git a/jax/BUILD b/jax/BUILD
index 9a98cf4fa69a..ea3ddcc76d1d 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -389,6 +389,7 @@ pytype_strict_library(
     name = "cloud_tpu_init",
     srcs = ["_src/cloud_tpu_init.py"],
     deps = [
+        ":config",
         ":hardware_utils",
         ":version",
     ],
diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py
index 68fe25562655..71827d8f8a9f 100644
--- a/jax/_src/cloud_tpu_init.py
+++ b/jax/_src/cloud_tpu_init.py
@@ -13,8 +13,9 @@
 # limitations under the License.
 
 import os
-from jax._src import hardware_utils
 from jax import version
+from jax._src import config
+from jax._src import hardware_utils
 
 running_in_cloud_tpu_vm: bool = False
@@ -73,3 +74,9 @@ def cloud_tpu_init() -> None:
   # this makes tensorstore serialization work better on TPU
   os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60')
   os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES', '256')
+
+  if config.jax_pjrt_client_create_options.value is None:
+    config.update(
+        'jax_pjrt_client_create_options',
+        f'ml_framework_name:JAX;ml_framework_version:{version.__version__}'
+    )
diff --git a/jax/_src/config.py b/jax/_src/config.py
index 8ae6dab0b03f..4ee6b16abba1 100644
--- a/jax/_src/config.py
+++ b/jax/_src/config.py
@@ -935,6 +935,12 @@ def update_thread_local_jit_state(**kw):
         'otherwise.'
     ))
 
+jax_pjrt_client_create_options = define_optional_string_state(
+    name='jax_pjrt_client_create_options',
+    default=None,
+    help=('A set of key-value pairs in the format of "k1:v1;k2:v2" strings '
+          'provided to a device platform pjrt client as extra arguments.'))
+
 enable_checks = define_bool_state(
     name='jax_enable_checks',
     default=False,
diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 55ec8455dbba..c68d39abf96f 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -47,6 +47,7 @@
 from jax._src.lib import cuda_versions
 from jax._src.lib import xla_client
 from jax._src.lib import xla_extension
+from jax._src.lib import xla_extension_version
 from jax._src.lib import jaxlib
 
 logger = logging.getLogger(__name__)
@@ -160,7 +161,13 @@ def _log_warning():
     t.start()
 
   try:
-    client = xla_client.make_tpu_client(_get_tpu_library_path())
+    if xla_extension_version >= 267:
+      client = xla_client.make_tpu_client(  # type: ignore
+          _get_tpu_library_path(),
+          _options_from_jax_configs("tpu"))
+    else:
+      client = xla_client.make_tpu_client(
+          _get_tpu_library_path())
   finally:
     t.cancel()
 
@@ -618,16 +625,30 @@ def discover_pjrt_plugins() -> None:
 
 
 def _options_from_jax_configs(plugin_name):
-  if plugin_name != "cuda":
-    return {}
-
   options = {}
-  visible_devices = CUDA_VISIBLE_DEVICES.value
-  if visible_devices != 'all':
-    options['visible_devices'] = [int(x) for x in visible_devices.split(',')]
-  options['enable_mock_nccl'] = _USE_MOCK_GPU_CLIENT.value
-  if options['enable_mock_nccl']:
-    options['num_nodes'] = _MOCK_NUM_GPUS.value
+
+  pjrt_client_options = config.jax_pjrt_client_create_options.value
+  pjrt_client_option_list = []
+  if pjrt_client_options:
+    pjrt_client_option_list = pjrt_client_options.split(";")
+
+  for option in pjrt_client_option_list:
+    option_list = option.split(":")
+    if (len(option_list) != 2):
+      raise RuntimeError(
+          "Multiple ':' separators for option in "
+          f"jax_pjrt_client_create_options: '{option}'. "
+          "Should be in format 'key:value'")
+    options[option_list[0]] = option_list[1]
+
+  if plugin_name == "cuda":
+    visible_devices = CUDA_VISIBLE_DEVICES.value
+    if visible_devices != 'all':
+      options['visible_devices'] = [int(x) for x in visible_devices.split(',')]
+    options['enable_mock_nccl'] = _USE_MOCK_GPU_CLIENT.value
+    if options['enable_mock_nccl']:
+      options['num_nodes'] = _MOCK_NUM_GPUS.value
+
   return options
diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py
index ecbf2e202fd0..d118d0e6454b 100644
--- a/tests/xla_bridge_test.py
+++ b/tests/xla_bridge_test.py
@@ -26,6 +26,7 @@
 from jax._src import xla_bridge as xb
 from jax._src.interpreters import xla
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 
 config.parse_flags_with_absl()
 
@@ -143,7 +144,7 @@ def test_timer_tpu_warning(self):
     with warnings.catch_warnings(record=True) as w:
       warnings.simplefilter("always")
 
-      def _mock_tpu_client(library_path=None):
+      def _mock_tpu_client_with_options(library_path=None, options=None):
         time_to_wait = 5
         start = time.time()
         while not w:
@@ -157,9 +158,17 @@ def _mock_tpu_client(library_path=None):
         msg = str(w[-1].message)
         self.assertIn("Did you run your code on all TPU hosts?", msg)
 
-      with mock.patch.object(xc, "make_tpu_client",
-                             side_effect=_mock_tpu_client):
-        xb.tpu_client_timer_callback(0.01)
+      def _mock_tpu_client(library_path=None):
+        _mock_tpu_client_with_options(library_path=library_path, options=None)
+
+      if xla_extension_version >= 267:
+        with mock.patch.object(xc, "make_tpu_client",
+                               side_effect=_mock_tpu_client_with_options):
+          xb.tpu_client_timer_callback(0.01)
+      else:
+        with mock.patch.object(xc, "make_tpu_client",
+                               side_effect=_mock_tpu_client):
+          xb.tpu_client_timer_callback(0.01)
 
   def test_register_plugin(self):
     with self.assertLogs(level="WARNING") as log_output:

From 26b4848ff0db4a573357609c687da69b71027d7c Mon Sep 17 00:00:00 2001
From: Colin Gaffney
Date: Tue, 28 May 2024 14:20:39 -0700
Subject: [PATCH 14/18] Remove warning logs for primary_host/remote_storage
 "incompatibility".

PiperOrigin-RevId: 638022129
---
 jax/experimental/array_serialization/serialization.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py
index 7ef1d6ad73a7..9d947527bb1c 100644
--- a/jax/experimental/array_serialization/serialization.py
+++ b/jax/experimental/array_serialization/serialization.py
@@ -216,14 +216,6 @@ async def async_serialize(
         f'between processes. Serialization have failed for the array with '
         f'the path "{tensorstore_spec["kvstore"]["path"]}".')
 
-  if primary_host is None and is_remote_storage(tensorstore_spec):
-    # Not strictly an error because users may manually split directories into
-    # per-process subdirectories.
-    logging.warning(
-        'When primary_host is set to None and remote storage is used,'
-        ' serialization is not allowed, as this may lead to a race condition'
-        ' between processes.'
-    )
   # 'metadata' may not be present at the top level (for example, if we are using
   # a 'cast' driver).
   if not _spec_has_metadata(tensorstore_spec):

From b47157dbe6bd24026e2d7b6a493041a04297f905 Mon Sep 17 00:00:00 2001
From: Parker Schuh
Date: Tue, 28 May 2024 15:19:36 -0700
Subject: [PATCH 15/18] Move throwing the error for invalid compute_on into
 the try-finally context.

This should avoid leaking compute_on contexts.
PiperOrigin-RevId: 638040571
---
 jax/_src/compute_on.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/jax/_src/compute_on.py b/jax/_src/compute_on.py
index 53ca50f2ef03..25b2be78d287 100644
--- a/jax/_src/compute_on.py
+++ b/jax/_src/compute_on.py
@@ -28,11 +28,11 @@ def __init__(self):
 @contextmanager
 def extend_compute_type(c_type: str):
   compute_on_context.stack.append(c_type)
-  if len(set(filter(lambda x: x is not None, set(compute_on_context.stack)))) > 1:
-    raise NotImplementedError(
-        'Nesting `compute_on` with different compute types is not supported'
-        f' yet. Current stack: {compute_on_context.stack}')
   try:
+    if len(set(filter(lambda x: x is not None, set(compute_on_context.stack)))) > 1:
+      raise NotImplementedError(
+          'Nesting `compute_on` with different compute types is not supported'
+          f' yet. Current stack: {compute_on_context.stack}')
     yield compute_on_context.stack[-1]
   finally:
     compute_on_context.stack.pop()
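The order matters here because the stack is appended to before the check runs: with the raise outside the try block, the finally never executed and the freshly pushed entry leaked. A self-contained sketch of the corrected pattern (all names illustrative):

    from contextlib import contextmanager

    stack = []

    @contextmanager
    def extend(value):
      stack.append(value)
      try:
        # The check now sits inside the try, so the finally below always
        # pops the entry we just pushed, even when the check raises.
        if len(set(stack)) > 1:
          raise NotImplementedError(f'mixed values: {stack}')
        yield stack[-1]
      finally:
        stack.pop()

    try:
      with extend('host'):
        with extend('device'):  # raises, yet both entries are popped
          pass
    except NotImplementedError:
      pass
    assert stack == []  # nothing leaked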
From 51e595185804ee257f1f223deac6aefb2d60f66e Mon Sep 17 00:00:00 2001
From: Jieying Luo
Date: Tue, 28 May 2024 15:43:45 -0700
Subject: [PATCH 16/18] Use opaque layout PJRT_Layouts_MemoryLayout in
 PjRtCApiBuffer::layout() to keep all the layout information.

PjRtCApiBuffer::layout() was using PJRT_Buffer_GetMemoryLayout, which will be
deprecated; its explicit PJRT_Buffer_MemoryLayout does not contain all the
layout information.

PiperOrigin-RevId: 638048293
---
 jax/experimental/array_serialization/serialization_test.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py
index a35a8f326381..f0bf4f456292 100644
--- a/jax/experimental/array_serialization/serialization_test.py
+++ b/jax/experimental/array_serialization/serialization_test.py
@@ -461,9 +461,6 @@ def test_load_with_layout(self):
       self.assertArraysEqual(s.data, np_inp[s.index])

   def test_deserialization_with_int4(self):
-    if xb.using_pjrt_c_api() and xb.get_backend().platform == "gpu":
-      self.skipTest('b/342255612')
-
     dtype = jnp.int4
     shape = (8, 2)
     arr = jnp.arange(np.prod(shape)).reshape(shape).astype(dtype)

From 4fae9aa1603b67f03c295bd196258e7f0c4144e7 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Tue, 28 May 2024 16:58:33 -0700
Subject: [PATCH 17/18] Support eager unified memory computations

PiperOrigin-RevId: 638073121
---
 jax/_src/interpreters/mlir.py | 29 +++++++++++++++++--
 jax/_src/pjit.py              | 13 +--------
 tests/memories_test.py        | 54 ++++++++++++++++++++----------------
 3 files changed, 58 insertions(+), 38 deletions(-)

diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 9f333ea3897e..f5f0a2df42c9 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1682,7 +1682,11 @@ def lower_per_platform(ctx: LoweringRuleContext,
   assert kept_rules
   # If there is a single rule left, just apply it, without conditionals.
   if len(kept_rules) == 1:
-    return kept_rules[0](ctx, *rule_args, **rule_kwargs)
+    output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
+    wrapped_out = map(wrap_singleton_ir_values, output)
+    map(lambda o: wrap_compute_type_in_place(ctx, o.owner),
+        util.flatten(wrapped_out))
+    return output

   assert len(platforms) > 1 and len(kept_rules) >= 2, (platforms, kept_rules)
   assert len(ctx.dim_var_values) >= 1, "Must have a platform_index variable"
@@ -1716,6 +1720,8 @@ def lower_per_platform(ctx: LoweringRuleContext,
       except TypeError as e:
         raise ValueError("Output of translation rule must be iterable: "
                          f"{description}, got output {output}") from e
+      map(lambda o: wrap_compute_type_in_place(ctx, o.owner),
+          util.flatten(out_nodes))
       if inner_ctx.tokens_out is not None:
         assert len(ordered_effects) == len(inner_ctx.tokens_out)
         out_nodes = [inner_ctx.tokens_out.get(eff)
@@ -1854,6 +1860,21 @@ def core_call_lowering(ctx: LoweringRuleContext,
 register_lowering(core.closed_call_p, partial(core_call_lowering, name=None))

+def map_compute_type(c_type):
+  if c_type == 'device_host':
+    return 'host'
+  elif c_type == 'device':
+    return 'dense'
+  raise ValueError('Invalid compute type received. Current supported values '
+                   'are `device_host` and `device`')
+
+def wrap_compute_type_in_place(ctx, op):
+  if ctx.compute_type is not None:
+    dict_attr = {"_xla_compute_type": ir.StringAttr.get(
+        map_compute_type(ctx.compute_type))}
+    op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)
+
+
 def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *,
                      broadcast_dimensions) -> ir.Value:
   # broadcast_dimension[i] is the axis of the result where the axis i of
@@ -1870,7 +1891,7 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue,
   else:
     if not core.is_constant_shape(aval_out.shape):  # type: ignore
       shape = eval_dynamic_shape_as_tensor(ctx, aval_out.shape)  # type: ignore
-      return hlo.dynamic_broadcast_in_dim(
+      out = hlo.dynamic_broadcast_in_dim(
           aval_to_ir_type(aval_out), op,
           shape,
           dense_int_array(broadcast_dimensions),
@@ -1878,9 +1899,11 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue,
     else:
       assert all(d != ir.ShapedType.get_dynamic_size()
                  for d in aval_out.shape), aval_out  # type: ignore
-      return hlo.broadcast_in_dim(
+      out = hlo.broadcast_in_dim(
          aval_to_ir_type(aval_out), op,
          dense_int_array(broadcast_dimensions))
+  wrap_compute_type_in_place(ctx, out.owner)
+  return out

 def multi_broadcast_in_dim(ctx: LoweringRuleContext,
                            ops: Sequence[ir.Value],
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index a79c7de0edb2..43bc77836613 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -1712,14 +1712,6 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
     mod_ctx.cached_primitive_lowerings[key] = func
   return func

-def _map_compute_type(c_type):
-  if c_type == 'device_host':
-    return 'host'
-  elif c_type == 'device':
-    return 'dense'
-  raise ValueError('Invalid compute type received. Current supported values '
-                   'are `device_host` and `device`')
-
 def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
                    out_shardings, in_layouts, out_layouts, resource_env,
@@ -1739,10 +1731,7 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
     call = func_dialect.CallOp(flat_output_types,
                                ir.FlatSymbolRefAttr.get(func.name.value),
                                mlir.flatten_lowering_ir_args(args))
-    if ctx.compute_type is not None:
-      dict_attr = {"_xla_compute_type": ir.StringAttr.get(
-          _map_compute_type(ctx.compute_type))}
-      call.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)
+    mlir.wrap_compute_type_in_place(ctx, call)
     out_nodes = unflatten(call.results, map(len, output_types))
     tokens, out_nodes = split_list(out_nodes, [len(effects)])
     tokens_out = ctx.tokens_in.update_tokens(mlir.TokenSet(zip(effects, tokens)))
diff --git a/tests/memories_test.py b/tests/memories_test.py
index b9d1bc4895c1..782a2ec59725 100644
--- a/tests/memories_test.py
+++ b/tests/memories_test.py
@@ -14,6 +14,7 @@

 import functools
 import math
+import re
 from absl.testing import absltest
 from absl.testing import parameterized
 from absl import flags
@@ -630,8 +631,9 @@ def f(x):

     jtu.check_grads(jf, (inp,), order=2)

-    lowered_text = jf.lower(inp).as_text()
-    self.assertEqual(lowered_text.count('_xla_compute_type = "host"'), 2)
+    lowered_text = jf.lower(inp).as_text('hlo')
+    out = re.findall(r"call.*to_apply.*_xla_compute_type", lowered_text)
+    self.assertLen(out, 2)

   def test_compute_on_remat(self):
     inp = jnp.arange(16.)
@@ -656,8 +658,9 @@ def f(x):
     jf = jax.jit(jax.grad(f))
     jf(inp)  # doesn't crash

-    lowered_text = jf.lower(inp).as_text()
-    self.assertEqual(lowered_text.count('_xla_compute_type = "host"'), 2)
+    lowered_text = jf.lower(inp).as_text('hlo')
+    out = re.findall(r"call.*to_apply.*_xla_compute_type", lowered_text)
+    self.assertLen(out, 2)

   def test_nested_no_op_compute(self):
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
@@ -784,25 +787,30 @@ def f(x):
     self.assertEqual(out.sharding.memory_kind, 'pinned_host')
     self.assertArraysEqual(out, np_inp * np_inp)

-  # def test_eager_compute(self):
-  #   inp = jnp.arange(8)
-  #   with compute_on('device_host'):
-  #     a = inp * 2
-  #   print(a)
-
-  # def test_compute_only_host(self):
-  #   @compute_on('device_host')
-  #   @jax.jit
-  #   def f(x):
-  #     return x * 2
-  #   f(jnp.arange(8))
-
-  # def test_per_annotation_wrapper(self):
-  #   @jax.jit
-  #   @compute_on('device_host')
-  #   def f(x):
-  #     return x * 2
-  #   f(jnp.arange(8))
+  def test_eager_compute(self):
+    inp = jnp.arange(8.)
+    with compute_on('device_host'):
+      out = inp * 2
+      out = jnp.sin(out)
+    self.assertArraysAllClose(out, jnp.sin(inp * 2))
+
+  def test_compute_per_annotation(self):
+    mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
+    s = NamedSharding(mesh, P("x", "y"))
+    np_inp = np.arange(16.).reshape(8, 2)
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    @compute_on('device_host')
+    def f(x):
+      return jnp.sin(x * 2)
+
+    # sharded input
+    out = f(arr)
+    self.assertArraysAllClose(out, np.sin(np_inp * 2))
+
+    out2 = f(np_inp)
+    self.assertArraysAllClose(out2, np.sin(np_inp * 2))

   def test_jit_host_multi_outputs(self):
     _, s, np_inp, inp = _create_inputs((8, 2), P("x"))
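A minimal sketch of what this patch enables, assuming compute_on is exported from jax.experimental.compute_on (the tests above exercise the same paths): eager ops inside the context now run without jit, and under jit the annotation surfaces as an _xla_compute_type frontend attribute in the lowered IR:

    import jax
    import jax.numpy as jnp
    from jax.experimental.compute_on import compute_on

    x = jnp.arange(8.)

    # Eager: ops issued inside the context are annotated to run on the host.
    with compute_on('device_host'):
      y = jnp.sin(x * 2)

    # Staged: the whole jitted call carries the host compute type.
    @jax.jit
    @compute_on('device_host')
    def f(x):
      return jnp.sin(x * 2)

    print(f.lower(x).as_text())  # look for _xla_compute_type = "host"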
From 6c51234f9ca15725b7d4171e475066f77b5e928f Mon Sep 17 00:00:00 2001
From: jax authors
Date: Tue, 28 May 2024 19:23:29 -0700
Subject: [PATCH 18/18] Update XLA dependency to use revision
 http://github.com/openxla/xla/commit/940e3a27542b7ce76666173e7b287aa2a9263916.

PiperOrigin-RevId: 638107181
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 030f1e224fb7..a1ff12724d2e 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.

-XLA_COMMIT = "98db3e8c8f64dede911fd97605f76aaf6ede1153"
-XLA_SHA256 = "3bd4b3a121840edfe253a35dafeea9856f78071fb3610a2ce36db832442e8b7e"
+XLA_COMMIT = "940e3a27542b7ce76666173e7b287aa2a9263916"
+XLA_SHA256 = "bcdc778e5a456839869dea796117b723bdea488075bd9555fe118fd8d6fcf25e"

 def repo():
     tf_http_archive(