Skip to content

Commit

Permalink
Merge branch 'google:main' into lru-cache
Browse files Browse the repository at this point in the history
  • Loading branch information
ayaka14732 committed Jun 11, 2024
2 parents dfb947e + c5761b7 commit d38c307
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ Remember to align the itemized text with the first line of an item within a list
* {mod}`jax.random` APIs no longer accept batched keys, where previously
some did unintentionally. Going forward, we recommend explicit use of
{func}`jax.vmap` in such cases.
* In {func}`jax.scipy.special.beta`, the `x` and `y` parameters have been
renamed to `a` and `b` for consistency with other `beta` APIs.

* New Functionality
* Added {func}`jax.experimental.Exported.in_shardings_jax` to construct
Expand Down
4 changes: 4 additions & 0 deletions build/requirements_lock_3_10.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ execnet==2.1.1 \
--hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \
--hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3
# via pytest-xdist
filelock==3.14.0 \
--hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \
--hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a
# via -r build/test-requirements.txt
flatbuffers==24.3.25 \
--hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \
--hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4
Expand Down
4 changes: 4 additions & 0 deletions build/requirements_lock_3_11.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ execnet==2.1.1 \
--hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \
--hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3
# via pytest-xdist
filelock==3.14.0 \
--hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \
--hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a
# via -r build/test-requirements.txt
flatbuffers==24.3.25 \
--hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \
--hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4
Expand Down
4 changes: 4 additions & 0 deletions build/requirements_lock_3_12.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ execnet==2.1.1 \
--hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \
--hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3
# via pytest-xdist
filelock==3.14.0 \
--hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \
--hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a
# via -r build/test-requirements.txt
flatbuffers==24.3.25 \
--hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \
--hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4
Expand Down
4 changes: 4 additions & 0 deletions build/requirements_lock_3_9.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ execnet==2.1.1 \
--hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \
--hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3
# via pytest-xdist
filelock==3.14.0 \
--hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \
--hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a
# via -r build/test-requirements.txt
flatbuffers==24.3.25 \
--hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \
--hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4
Expand Down
20 changes: 16 additions & 4 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1232,12 +1232,24 @@ def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]:
k: integer specifying the number of top entries.
Returns:
values: array containing the top k values along the last axis.
indices: array containing the indices corresponding to values.
A tuple ``(values, indices)`` where
- ``values`` is an array containing the top k values along the last axis.
- ``indices`` is an array containing the indices corresponding to values.
See also:
- :func:`jax.lax.approx_max_k`
- :func:`jax.lax.approx_min_k`
- :func:`jax.lax.approx_max_k`
- :func:`jax.lax.approx_min_k`
Example:
Find the largest three values, and their indices, within an array:
>>> x = jnp.array([9., 3., 6., 4., 10.])
>>> values, indices = jax.lax.top_k(x, 3)
>>> values
Array([10., 9., 6.], dtype=float32)
>>> indices
Array([4, 0, 2], dtype=int32)
"""
if core.is_constant_dim(k):
k = int(k)
Expand Down
37 changes: 32 additions & 5 deletions jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

from functools import partial
import operator
from typing import cast, Any
from typing import cast, overload, Any
import warnings

import numpy as np

Expand Down Expand Up @@ -186,8 +187,16 @@ def factorial(n: ArrayLike, exact: bool = False) -> Array:
n, = promote_args_inexact("factorial", n)
return jnp.where(n < 0, 0, lax.exp(lax.lgamma(n + 1)))

@overload
def beta(a: ArrayLike, b: ArrayLike) -> Array: ...

def beta(x: ArrayLike, y: ArrayLike) -> Array:
@overload
def beta(a: ArrayLike, *, y: ArrayLike) -> Array: ...

@overload
def beta(*, x: ArrayLike, y: ArrayLike) -> Array: ...

def beta(*args, **kwds):
r"""The beta function
JAX implementation of :obj:`scipy.special.beta`.
Expand All @@ -209,9 +218,27 @@ def beta(x: ArrayLike, y: ArrayLike) -> Array:
- :func:`jax.scipy.special.gamma`
- :func:`jax.scipy.special.betaln`
"""
x, y = promote_args_inexact("beta", x, y)
sign = gammasgn(x) * gammasgn(y) * gammasgn(x + y)
return sign * lax.exp(betaln(x, y))
# TODO(jakevdp): deprecation warning added 2024-06-10; finalize after 2024-09-10
if 'x' in kwds:
warnings.warn("The `x` parameter of jax.scipy.special.beta is deprecated, use `a` instead.",
category=DeprecationWarning, stacklevel=2)
if 'a' in kwds:
raise TypeError("beta() got both parameter 'a' and parameter 'x'.")
kwds['a'] = kwds.pop('x')
if 'y' in kwds:
warnings.warn("The `y` parameter of jax.scipy.special.beta is deprecated, use `b` instead.",
category=DeprecationWarning, stacklevel=2)
if 'b' in kwds:
raise TypeError("beta() got both parameter 'b' and parameter 'y'.")
kwds['b'] = kwds.pop('y')
if extra := kwds.keys() - {'a', 'b'}:
raise TypeError(f"beta() got unexpected keyword arguments {list(extra)}")
return _beta(*args, **kwds)

def _beta(a, b):
a, b = promote_args_inexact("beta", a, b)
sign = gammasgn(a) * gammasgn(b) * gammasgn(a + b)
return sign * lax.exp(betaln(a, b))


def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
Expand Down
3 changes: 3 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel) {
if (fuel <= 0) {
return false;
}
if (divisor == 1) {
return true;
}
if (auto assume_op = value.getDefiningOp<tpu::AssumeMultipleOp>()) {
return assume_op.getMultiple() % divisor == 0;
}
Expand Down
18 changes: 18 additions & 0 deletions tests/lax_scipy_special_functions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,24 @@ def testRelEntrExtremeValues(self):
self._CheckAgainstNumpy(osp_special.rel_entr, lsp_special.rel_entr, args_maker, rtol=rtol)
self._CompileAndCheck(lsp_special.rel_entr, args_maker, rtol=rtol)

def testBetaParameterDeprecation(self):
with self.assertNoWarnings():
lsp_special.beta(1, 1)
lsp_special.beta(1, b=1)
lsp_special.beta(a=1, b=1)
with self.assertWarns(DeprecationWarning):
lsp_special.beta(1, y=1)
with self.assertWarns(DeprecationWarning):
lsp_special.beta(a=1, y=1)
with self.assertWarns(DeprecationWarning):
lsp_special.beta(x=1, b=1)
with self.assertWarns(DeprecationWarning):
lsp_special.beta(x=1, y=1)
with self.assertRaises(TypeError), self.assertWarns(DeprecationWarning):
lsp_special.beta(1, x=1)
with self.assertRaises(TypeError), self.assertWarns(DeprecationWarning):
lsp_special.beta(b=1, y=1)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit d38c307

Please sign in to comment.