Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose existing functions in array API namespace #20753

Merged
merged 1 commit into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
ceil as ceil,
clip as clip,
conj as conj,
copysign as copysign,
cos as cos,
cosh as cosh,
divide as divide,
Expand All @@ -139,6 +140,8 @@
logical_not as logical_not,
logical_or as logical_or,
logical_xor as logical_xor,
maximum as maximum,
minimum as minimum,
multiply as multiply,
negative as negative,
not_equal as not_equal,
Expand All @@ -148,6 +151,7 @@
remainder as remainder,
round as round,
sign as sign,
signbit as signbit,
sin as sin,
sinh as sinh,
sqrt as sqrt,
Expand All @@ -168,7 +172,9 @@
concat as concat,
expand_dims as expand_dims,
flip as flip,
moveaxis as moveaxis,
permute_dims as permute_dims,
repeat as repeat,
reshape as reshape,
roll as roll,
squeeze as squeeze,
Expand All @@ -179,6 +185,7 @@
argmax as argmax,
argmin as argmin,
nonzero as nonzero,
searchsorted as searchsorted,
where as where,
)

Expand Down
23 changes: 22 additions & 1 deletion jax/experimental/array_api/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
result_type as _result_type,
isdtype as _isdtype,
)
import numpy as np


def _promote_dtypes(name, *args):
Expand Down Expand Up @@ -148,6 +147,11 @@ def conj(x, /):
return jax.numpy.conj(x)


def copysign(x1, x2, /):
"""Composes a floating-point value with the magnitude of x1_i and the sign of x2_i for each element of the input array x1."""
return jax.numpy.copysign(x1, x2)


def cos(x, /):
"""Calculates an implementation-dependent approximation to the cosine for each element x_i of the input array x."""
x, = _promote_dtypes("cos", x)
Expand Down Expand Up @@ -300,6 +304,18 @@ def logical_xor(x1, x2, /):
return jax.numpy.logical_xor(x1, x2)


def maximum(x1, x2, /):
"""Computes the maximum value for each element x1_i of the input array x1 relative to the respective element x2_i of the input array x2."""
x1, x2 = _promote_dtypes("maximum", x1, x2)
return jax.numpy.maximum(x1, x2)


def minimum(x1, x2, /):
"""Computes the minimum value for each element x1_i of the input array x1 relative to the respective element x2_i of the input array x2."""
x1, x2 = _promote_dtypes("minimum", x1, x2)
return jax.numpy.minimum(x1, x2)


def multiply(x1, x2, /):
"""Calculates the product for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
x1, x2 = _promote_dtypes("multiply", x1, x2)
Expand Down Expand Up @@ -356,6 +372,11 @@ def sign(x, /):
return jax.numpy.sign(x)


def signbit(x, /):
"""Determines whether the sign bit is set for each element x_i of the input array x."""
return jax.numpy.signbit(x)


def sin(x, /):
"""Calculates an implementation-dependent approximation to the sine for each element x_i of the input array x."""
x, = _promote_dtypes("sin", x)
Expand Down
10 changes: 10 additions & 0 deletions jax/experimental/array_api/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,21 @@ def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None) -> Array:
return jax.numpy.flip(x, axis=axis)


def moveaxis(x: Array, source: int | tuple[int, ...], destination: int | tuple[int, ...], /) -> Array:
"""Moves array axes (dimensions) to new positions, while leaving other axes in their original positions."""
return jax.numpy.moveaxis(x, source, destination)


def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array:
"""Permutes the axes (dimensions) of an array x."""
return jax.numpy.permute_dims(x, axes=axes)


def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Array:
"""Repeats each element of an array a specified number of times on a per-element basis."""
return jax.numpy.repeat(x, repeats=repeats, axis=axis)


def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array:
"""Reshapes an array without changing its data."""
del copy # unused
Expand Down
9 changes: 9 additions & 0 deletions jax/experimental/array_api/_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ def nonzero(x, /):
return jax.numpy.nonzero(x)


def searchsorted(x1, x2, /, *, side='left', sorter=None):
"""
Finds the indices into x1 such that, if the corresponding elements in x2
were inserted before the indices, the order of x1, when sorted in ascending
order, would be preserved.
"""
return jax.numpy.searchsorted(x1, x2, side=side, sorter=sorter)


def where(condition, x1, x2, /):
"""Returns elements chosen from x1 or x2 depending on condition."""
dtype = _result_type(x1, x2)
Expand Down
7 changes: 7 additions & 0 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
'complex64',
'concat',
'conj',
'copysign',
'cos',
'cosh',
'divide',
Expand Down Expand Up @@ -115,9 +116,12 @@
'matmul',
'matrix_transpose',
'max',
'maximum',
'mean',
'meshgrid',
'min',
'minimum',
'moveaxis',
'multiply',
'nan',
'negative',
Expand All @@ -133,11 +137,14 @@
'prod',
'real',
'remainder',
'repeat',
'reshape',
'result_type',
'roll',
'round',
'searchsorted',
'sign',
'signbit',
'sin',
'sinh',
'sort',
Expand Down