Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve documentation of jax.numpy.clip #22193

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 45 additions & 28 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2246,48 +2246,65 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array
return _split("array_split", ary, indices_or_sections, axis=axis)


_DEPRECATED_CLIP_ARG = DeprecatedArg()
@util.implements(
np.clip,
skip_params=['a', 'a_min'],
extra_params=_dedent("""
x : array_like
Array containing elements to clip.
min : array_like, optional
Minimum value. If ``None``, clipping is not performed on the
corresponding edge. The value of ``min`` is broadcast against x.
max : array_like, optional
Maximum value. If ``None``, clipping is not performed on the
corresponding edge. The value of ``max`` is broadcast against x.
""")
)
@jit
def clip(
x: ArrayLike | None = None, # Default to preserve backwards compatability
arr: ArrayLike | None = None,
/,
min: ArrayLike | None = None,
max: ArrayLike | None = None,
*,
a: ArrayLike | DeprecatedArg = _DEPRECATED_CLIP_ARG,
a_min: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG,
a_max: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG
a: ArrayLike | DeprecatedArg = DeprecatedArg(),
a_min: ArrayLike | None | DeprecatedArg = DeprecatedArg(),
a_max: ArrayLike | None | DeprecatedArg = DeprecatedArg()
) -> Array:
"""Clip array values to a specified range.

JAX implementation of :func:`numpy.clip`.

Args:
arr: N-dimensional array to be clipped.
min: optional minimum value of the clipped range; if ``None`` (default) then
result will not be clipped to any minimum value. If specified, it should be
broadcast-compatible with ``arr`` and ``max``.
max: optional maximum value of the clipped range; if ``None`` (default) then
result will not be clipped to any maximum value. If specified, it should be
broadcast-compatible with ``arr`` and ``min``.
a: deprecated alias of the ``arr`` argument. Will result in a
:class:`DeprecationWarning` if used.
a_min: deprecated alias of the ``min`` argument. Will result in a
:class:`DeprecationWarning` if used.
a_max: deprecated alias of the ``max`` argument. Will result in a
:class:`DeprecationWarning` if used.

Returns:
An array containing values from ``arr``, with values smaller than ``min`` set
to ``min``, and values larger than ``max`` set to ``max``.

See also:
- :func:`jax.numpy.minimum`: Compute the element-wise minimum value of two arrays.
- :func:`jax.numpy.maximum`: Compute the element-wise maximum value of two arrays.

Examples:
>>> arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])
>>> jnp.clip(arr, 2, 5)
Array([2, 2, 2, 3, 4, 5, 5, 5], dtype=int32)
"""
# TODO(micky774): deprecated 2024-4-2, remove after deprecation expires.
x = a if not isinstance(a, DeprecatedArg) else x
if x is None:
arr = a if not isinstance(a, DeprecatedArg) else arr
if arr is None:
raise ValueError("No input was provided to the clip function.")
min = a_min if not isinstance(a_min, DeprecatedArg) else min
max = a_max if not isinstance(a_max, DeprecatedArg) else max
if any(not isinstance(t, DeprecatedArg) for t in (a, a_min, a_max)):
warnings.warn(
"Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy.clip is "
"deprecated. Please use 'x', 'min', and 'max' respectively instead.",
"Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is "
"deprecated. Please use 'arr', 'min' or 'max' respectively instead.",
DeprecationWarning,
stacklevel=2,
)

util.check_arraylike("clip", x)
if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)):
util.check_arraylike("clip", arr)
if any(jax.numpy.iscomplexobj(t) for t in (arr, min, max)):
# TODO(micky774): Deprecated 2024-4-2, remove after deprecation expires.
warnings.warn(
"Clip received a complex value either through the input or the min/max "
Expand All @@ -2298,10 +2315,10 @@ def clip(
DeprecationWarning, stacklevel=2,
)
if min is not None:
x = ufuncs.maximum(min, x)
arr = ufuncs.maximum(min, arr)
if max is not None:
x = ufuncs.minimum(max, x)
return asarray(x)
arr = ufuncs.minimum(max, arr)
return asarray(arr)

@util.implements(np.around, skip_params=['out'])
@partial(jit, static_argnames=('decimals',))
Expand Down