Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new unstack function to numpy/array_api namespaces #20755

Merged
merged 1 commit into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.27

* New Functionality
* Added {func}`jax.numpy.unstack`, following the addition of this function in
the array API 2023 standard, soon to be adopted by NumPy.

* Changes
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
Expand Down
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ namespace; they are listed below.
unique_values
unpackbits
unravel_index
unstack
unsignedinteger
unwrap
vander
Expand Down
12 changes: 12 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1887,6 +1887,18 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike],
new_arrays.append(expand_dims(a, axis))
return concatenate(new_arrays, axis=axis, dtype=dtype)

@util.implements(getattr(np, 'unstack', None))
@partial(jit, static_argnames="axis")
def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]:
util.check_arraylike("unstack", x)
x = asarray(x)
if x.ndim == 0:
raise ValueError(
"Unstack requires arrays with rank > 0, however a scalar array was "
"passed."
)
return tuple(moveaxis(x, axis, 0))

@util.implements(np.tile)
def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array:
util.check_arraylike("tile", A)
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@
roll as roll,
squeeze as squeeze,
stack as stack,
unstack as unstack,
)

from jax.experimental.array_api._searching_functions import (
Expand Down
4 changes: 4 additions & 0 deletions jax/experimental/array_api/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,7 @@ def stack(arrays: tuple[Array, ...] | list[Array], /, *, axis: int = 0) -> Array
"""Joins a sequence of arrays along a new axis."""
dtype = _result_type(*arrays)
return jax.numpy.stack(arrays, axis=axis, dtype=dtype)

def unstack(x: Array, /, *, axis: int = 0) -> tuple[Array, ...]:
"""Splits an array in a sequence of arrays along the given axis."""
return jax.numpy.unstack(x, axis=axis)
1 change: 1 addition & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@
unpackbits as unpackbits,
unravel_index as unravel_index,
unsignedinteger as unsignedinteger,
unstack as unstack,
unwrap as unwrap,
vander as vander,
vdot as vdot,
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,7 @@ def unpackbits(
) -> Array: ...
def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: ...
unsignedinteger = _np.unsignedinteger
def unstack(x: ArrayLike , /, *, axis: int = ...) -> tuple[Array, ...]: ...
def unwrap(p: ArrayLike, discont: Optional[ArrayLike] = ...,
axis: int = ..., period: ArrayLike = ...) -> Array: ...
def vander(
Expand Down
1 change: 1 addition & 0 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@
'unique_counts',
'unique_inverse',
'unique_values',
'unstack',
'var',
'vecdot',
'where',
Expand Down
21 changes: 21 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,27 @@ def f():
for a in out]
return f


@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in all_shapes
for axis in list(range(-len(shape), len(shape)))],
dtype=all_dtypes,
)
def testUnstack(self, shape, axis, dtype):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
if jnp.asarray(x).ndim == 0:
with self.assertRaisesRegex(ValueError, "Unstack requires arrays with"):
jnp.unstack(x, axis=axis)
return
y = jnp.unstack(x, axis=axis)
if shape[axis] == 0:
self.assertEqual(y, ())
else:
self.assertArraysEqual(jnp.moveaxis(jnp.array(y), 0, axis), x)


@parameterized.parameters(
[dtype for dtype in [jnp.bool, jnp.uint8, jnp.uint16, jnp.uint32,
jnp.uint64, jnp.int8, jnp.int16, jnp.int32, jnp.int64,
Expand Down