Skip to content

Commit

Permalink
Add new unstack function to numpy/array_api namespaces
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Apr 15, 2024
1 parent 64775d0 commit 6bdc83c
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 0 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.27

* New Functionality
* Added {func}`jax.numpy.unstack`, following the addition of this function in
the array API 2023 standard, soon to be adopted by NumPy.

* Changes
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
Expand Down
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ namespace; they are listed below.
unique_values
unpackbits
unravel_index
unstack
unsignedinteger
unwrap
vander
Expand Down
12 changes: 12 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1887,6 +1887,18 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike],
new_arrays.append(expand_dims(a, axis))
return concatenate(new_arrays, axis=axis, dtype=dtype)

@util.implements(getattr(np, 'unstack', None))
@partial(jit, static_argnames="axis")
def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]:
util.check_arraylike("unstack", x)
x = asarray(x)
if x.ndim == 0:
raise ValueError(
"Unstack requires arrays with rank > 0, however a scalar array was "
"passed."
)
return tuple(moveaxis(x, axis, 0))

@util.implements(np.tile)
def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array:
util.check_arraylike("tile", A)
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@
roll as roll,
squeeze as squeeze,
stack as stack,
unstack as unstack,
)

from jax.experimental.array_api._searching_functions import (
Expand Down
4 changes: 4 additions & 0 deletions jax/experimental/array_api/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,7 @@ def stack(arrays: tuple[Array, ...] | list[Array], /, *, axis: int = 0) -> Array
"""Joins a sequence of arrays along a new axis."""
dtype = _result_type(*arrays)
return jax.numpy.stack(arrays, axis=axis, dtype=dtype)

def unstack(x: Array, /, *, axis: int = 0) -> tuple[Array, ...]:
"""Splits an array in a sequence of arrays along the given axis."""
return jax.numpy.unstack(x, axis=axis)
1 change: 1 addition & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@
unpackbits as unpackbits,
unravel_index as unravel_index,
unsignedinteger as unsignedinteger,
unstack as unstack,
unwrap as unwrap,
vander as vander,
vdot as vdot,
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,7 @@ def unpackbits(
) -> Array: ...
def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: ...
unsignedinteger = _np.unsignedinteger
def unstack(x: ArrayLike , /, *, axis: int = ...) -> tuple[Array, ...]: ...
def unwrap(p: ArrayLike, discont: Optional[ArrayLike] = ...,
axis: int = ..., period: ArrayLike = ...) -> Array: ...
def vander(
Expand Down
1 change: 1 addition & 0 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@
'unique_counts',
'unique_inverse',
'unique_values',
'unstack',
'var',
'vecdot',
'where',
Expand Down
21 changes: 21 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,27 @@ def f():
for a in out]
return f


@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in all_shapes
for axis in list(range(-len(shape), len(shape)))],
dtype=all_dtypes,
)
def testUnstack(self, shape, axis, dtype):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
if jnp.asarray(x).ndim == 0:
with self.assertRaisesRegex(ValueError, "Unstack requires arrays with"):
jnp.unstack(x, axis=axis)
return
y = jnp.unstack(x, axis=axis)
if shape[axis] == 0:
self.assertEqual(y, ())
else:
self.assertArraysEqual(jnp.moveaxis(jnp.array(y), 0, axis), x)


@parameterized.parameters(
[dtype for dtype in [jnp.bool, jnp.uint8, jnp.uint16, jnp.uint32,
jnp.uint64, jnp.int8, jnp.int16, jnp.int32, jnp.int64,
Expand Down

0 comments on commit 6bdc83c

Please sign in to comment.