Skip to content

Commit

Permalink
Merge pull request #20754 from Micky774:array-api-hypot
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 625035601
  • Loading branch information
jax authors committed Apr 15, 2024
2 parents eabd1bc + 2899213 commit 5f22b12
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 6 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ Remember to align the itemized text with the first line of an item within a list
* In {func}`jax.jit`, passing invalid `static_argnums` or `static_argnames`
now leads to an error rather than a warning.
* The minimum jaxlib version is now 0.4.23.
* The {func}`jax.numpy.hypot` function now issues a deprecation warning when
passing complex-valued inputs to it. This will raise an error when the
deprecation is completed.

## jaxlib 0.4.27

Expand Down
21 changes: 16 additions & 5 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import operator
from textwrap import dedent
from typing import Any, Callable, overload
import warnings

import numpy as np

Expand Down Expand Up @@ -730,12 +731,22 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
@implements(np.hypot, module='numpy')
@jit
def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike("hypot", x1, x2)
x1, x2 = promote_dtypes_inexact(x1, x2)
x1 = lax.abs(x1)
x2 = lax.abs(x2)
x1, x2 = promote_args_inexact("hypot", x1, x2)

# TODO(micky774): Promote to ValueError when deprecation is complete
# (began 2024-4-14).
if dtypes.issubdtype(x1.dtype, np.complexfloating):
warnings.warn(
"Passing complex-valued inputs to hypot is deprecated and will raise a "
"ValueError in the future. Please convert to real values first, such as "
"by using jnp.real or jnp.imag to take the real or imaginary components "
"respectively.",
DeprecationWarning, stacklevel=2)
x1, x2 = lax.abs(x1), lax.abs(x2)
idx_inf = lax.bitwise_or(isposinf(x1), isposinf(x2))
x1, x2 = maximum(x1, x2), minimum(x1, x2)
return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, lax._ones(x1), x1)))))
x = _where(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, _where(x1 == 0, lax._ones(x1), x1)))))
return _where(idx_inf, _lax_const(x, np.inf), x)


@implements(np.reciprocal, module='numpy')
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
floor_divide as floor_divide,
greater as greater,
greater_equal as greater_equal,
hypot as hypot,
imag as imag,
isfinite as isfinite,
isinf as isinf,
Expand Down
15 changes: 15 additions & 0 deletions jax/experimental/array_api/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import jax
from jax._src.dtypes import issubdtype
from jax.experimental.array_api._data_type_functions import (
result_type as _result_type,
isdtype as _isdtype,
Expand Down Expand Up @@ -214,6 +215,20 @@ def greater_equal(x1, x2, /):
return jax.numpy.greater_equal(x1, x2)


def hypot(x1, x2, /):
"""Computes the square root of the sum of squares for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
x1, x2 = _promote_dtypes("hypot", x1, x2)

# TODO(micky774): Remove when jnp.hypot deprecation is completed
# (began 2024-4-14) and default behavior is Array API 2023 compliant
if issubdtype(x1.dtype, jax.numpy.complexfloating):
raise ValueError(
"hypot does not support complex-valued inputs. Please convert to real "
"values first, such as by using jnp.real or jnp.imag to take the real "
"or imaginary components respectively.")
return jax.numpy.hypot(x1, x2)


def imag(x, /):
"""Returns the imaginary component of a complex number for each element x_i of the input array x."""
x, = _promote_dtypes("imag", x)
Expand Down
1 change: 1 addition & 0 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
'full_like',
'greater',
'greater_equal',
'hypot',
'iinfo',
'imag',
'inf',
Expand Down
2 changes: 1 addition & 1 deletion tests/lax_numpy_operators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
op_record("fmod", 2, default_dtypes, all_shapes, jtu.rand_some_nan, []),
op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [],
inexact=True),
op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [],
op_record("hypot", 2, real_dtypes, all_shapes, jtu.rand_default, [],
inexact=True),
op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default, []),
op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []),
Expand Down
20 changes: 20 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,26 @@ def testClipComplexInputDeprecation(self, shape):
jnp.clip(x, max=jnp.array([-1+5j]))


# TODO(micky774): Check for ValueError instead of DeprecationWarning when
# jnp.hypot deprecation is completed (began 2024-4-2) and default behavior is
# Array API 2023 compliant
@jtu.sample_product(shape=all_shapes)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testHypotComplexInputDeprecation(self, shape):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype=jnp.complex64)
msg = "Passing complex-valued inputs to hypot"
# jit is disabled so we don't miss warnings due to caching.
with jax.disable_jit():
with self.assertWarns(DeprecationWarning, msg=msg):
jnp.hypot(x, x)

with self.assertWarns(DeprecationWarning, msg=msg):
y = jnp.ones_like(x)
jnp.hypot(x, y)


@jtu.sample_product(
[dict(shape=shape, dtype=dtype)
for shape, dtype in _shape_and_dtypes(all_shapes, number_dtypes)],
Expand Down

0 comments on commit 5f22b12

Please sign in to comment.