diff --git a/CHANGELOG.md b/CHANGELOG.md index e753ece004c2..bf9f953eb773 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,8 @@ Remember to align the itemized text with the first line of an item within a list * {mod}`jax.random` APIs no longer accept batched keys, where previously some did unintentionally. Going forward, we recommend explicit use of {func}`jax.vmap` in such cases. + * In {func}`jax.scipy.special.beta`, the `x` and `y` parameters have been + renamed to `a` and `b` for consistency with other `beta` APIs. * New Functionality * Added {func}`jax.experimental.Exported.in_shardings_jax` to construct diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index c6d73524d221..511c4c70b591 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -88,6 +88,10 @@ execnet==2.1.1 \ --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 # via pytest-xdist +filelock==3.14.0 \ + --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ + --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a + # via -r build/test-requirements.txt flatbuffers==24.3.25 \ --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index e9649f45d32e..7d9840524665 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -82,6 +82,10 @@ execnet==2.1.1 \ --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 # via pytest-xdist +filelock==3.14.0 \ + --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ + --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a + # via -r build/test-requirements.txt flatbuffers==24.3.25 \ --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 606b042eb4aa..155a97d1e78d 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -82,6 +82,10 @@ execnet==2.1.1 \ --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 # via pytest-xdist +filelock==3.14.0 \ + --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ + --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a + # via -r build/test-requirements.txt flatbuffers==24.3.25 \ --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 diff --git a/build/requirements_lock_3_9.txt b/build/requirements_lock_3_9.txt index 105554c64633..773729380a1e 100644 --- a/build/requirements_lock_3_9.txt +++ b/build/requirements_lock_3_9.txt @@ -88,6 +88,10 @@ execnet==2.1.1 \ --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 # via pytest-xdist +filelock==3.14.0 \ + --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ + --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a + # via -r build/test-requirements.txt flatbuffers==24.3.25 \ --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 431319db6af8..d7f5d95ec29b 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1232,12 +1232,24 @@ def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]: k: integer specifying the number of top entries. Returns: - values: array containing the top k values along the last axis. - indices: array containing the indices corresponding to values. + A tuple ``(values, indices)`` where + + - ``values`` is an array containing the top k values along the last axis. + - ``indices`` is an array containing the indices corresponding to values. See also: - - :func:`jax.lax.approx_max_k` - - :func:`jax.lax.approx_min_k` + - :func:`jax.lax.approx_max_k` + - :func:`jax.lax.approx_min_k` + + Example: + Find the largest three values, and their indices, within an array: + + >>> x = jnp.array([9., 3., 6., 4., 10.]) + >>> values, indices = jax.lax.top_k(x, 3) + >>> values + Array([10., 9., 6.], dtype=float32) + >>> indices + Array([4, 0, 2], dtype=int32) """ if core.is_constant_dim(k): k = int(k) diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index e7ff5136b474..459605e07969 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -16,7 +16,8 @@ from functools import partial import operator -from typing import cast, Any +from typing import cast, overload, Any +import warnings import numpy as np @@ -186,8 +187,16 @@ def factorial(n: ArrayLike, exact: bool = False) -> Array: n, = promote_args_inexact("factorial", n) return jnp.where(n < 0, 0, lax.exp(lax.lgamma(n + 1))) +@overload +def beta(a: ArrayLike, b: ArrayLike) -> Array: ... -def beta(x: ArrayLike, y: ArrayLike) -> Array: +@overload +def beta(a: ArrayLike, *, y: ArrayLike) -> Array: ... + +@overload +def beta(*, x: ArrayLike, y: ArrayLike) -> Array: ... + +def beta(*args, **kwds): r"""The beta function JAX implementation of :obj:`scipy.special.beta`. @@ -209,9 +218,27 @@ def beta(x: ArrayLike, y: ArrayLike) -> Array: - :func:`jax.scipy.special.gamma` - :func:`jax.scipy.special.betaln` """ - x, y = promote_args_inexact("beta", x, y) - sign = gammasgn(x) * gammasgn(y) * gammasgn(x + y) - return sign * lax.exp(betaln(x, y)) + # TODO(jakevdp): deprecation warning added 2024-06-10; finalize after 2024-09-10 + if 'x' in kwds: + warnings.warn("The `x` parameter of jax.scipy.special.beta is deprecated, use `a` instead.", + category=DeprecationWarning, stacklevel=2) + if 'a' in kwds: + raise TypeError("beta() got both parameter 'a' and parameter 'x'.") + kwds['a'] = kwds.pop('x') + if 'y' in kwds: + warnings.warn("The `y` parameter of jax.scipy.special.beta is deprecated, use `b` instead.", + category=DeprecationWarning, stacklevel=2) + if 'b' in kwds: + raise TypeError("beta() got both parameter 'b' and parameter 'y'.") + kwds['b'] = kwds.pop('y') + if extra := kwds.keys() - {'a', 'b'}: + raise TypeError(f"beta() got unexpected keyword arguments {list(extra)}") + return _beta(*args, **kwds) + +def _beta(a, b): + a, b = promote_args_inexact("beta", a, b) + sign = gammasgn(a) * gammasgn(b) * gammasgn(a + b) + return sign * lax.exp(betaln(a, b)) def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index abc7cc595cd6..df00093fabe6 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -189,6 +189,9 @@ bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel) { if (fuel <= 0) { return false; } + if (divisor == 1) { + return true; + } if (auto assume_op = value.getDefiningOp()) { return assume_op.getMultiple() % divisor == 0; } diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index f950d1048df4..11ee38a37b16 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -237,6 +237,24 @@ def testRelEntrExtremeValues(self): self._CheckAgainstNumpy(osp_special.rel_entr, lsp_special.rel_entr, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.rel_entr, args_maker, rtol=rtol) + def testBetaParameterDeprecation(self): + with self.assertNoWarnings(): + lsp_special.beta(1, 1) + lsp_special.beta(1, b=1) + lsp_special.beta(a=1, b=1) + with self.assertWarns(DeprecationWarning): + lsp_special.beta(1, y=1) + with self.assertWarns(DeprecationWarning): + lsp_special.beta(a=1, y=1) + with self.assertWarns(DeprecationWarning): + lsp_special.beta(x=1, b=1) + with self.assertWarns(DeprecationWarning): + lsp_special.beta(x=1, y=1) + with self.assertRaises(TypeError), self.assertWarns(DeprecationWarning): + lsp_special.beta(1, x=1) + with self.assertRaises(TypeError), self.assertWarns(DeprecationWarning): + lsp_special.beta(b=1, y=1) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())