Skip to content

Commit

Permalink
Add new unstack function to numpy/array_api namespaces
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Apr 15, 2024
1 parent 2c85ca6 commit f507dc2
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 0 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.27

* New Functionality
* Added {func}`jax.numpy.unstack`, following the addition of this function in
the array API 2023 standard, soon to be adopted by NumPy.

* Changes
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
Expand Down
6 changes: 6 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1887,6 +1887,12 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike],
new_arrays.append(expand_dims(a, axis))
return concatenate(new_arrays, axis=axis, dtype=dtype)

@util.implements(getattr(np, 'unstack', None))
@partial(jit, static_argnames="axis")
def unstack(x: np.ndarray | Array, /, *, axis: int = 0) -> tuple[Array, ...]:
util.check_arraylike("unstack", x)
return tuple(moveaxis(x, axis, 0))

@util.implements(np.tile)
def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array:
util.check_arraylike("tile", A)
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@
roll as roll,
squeeze as squeeze,
stack as stack,
unstack as unstack,
)

from jax.experimental.array_api._searching_functions import (
Expand Down
4 changes: 4 additions & 0 deletions jax/experimental/array_api/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,7 @@ def stack(arrays: tuple[Array, ...] | list[Array], /, *, axis: int = 0) -> Array
"""Joins a sequence of arrays along a new axis."""
dtype = _result_type(*arrays)
return jax.numpy.stack(arrays, axis=axis, dtype=dtype)

def unstack(x: Array, /, *, axis: int = 0) -> tuple[Array, ...]:
"""Splits an array in a sequence of arrays along the given axis."""
return jax.numpy.unstack(x, axis=axis)
1 change: 1 addition & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@
unpackbits as unpackbits,
unravel_index as unravel_index,
unsignedinteger as unsignedinteger,
unstack as unstack,
unwrap as unwrap,
vander as vander,
vdot as vdot,
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,7 @@ def unpackbits(
) -> Array: ...
def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: ...
unsignedinteger = _np.unsignedinteger
def unstack(x: _np.ndarray | Array, /, *, axis: int = ...) -> tuple[Array, ...]: ...
def unwrap(p: ArrayLike, discont: Optional[ArrayLike] = ...,
axis: int = ..., period: ArrayLike = ...) -> Array: ...
def vander(
Expand Down
1 change: 1 addition & 0 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@
'unique_counts',
'unique_inverse',
'unique_values',
'unstack',
'var',
'vecdot',
'where',
Expand Down
14 changes: 14 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,20 @@ def f():
for a in out]
return f


@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in all_shapes
for axis in list(range(-len(shape), len(shape)))],
dtype=all_dtypes,
)
def testUnstack(self, shape, axis, dtype):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
y = jnp.array(jnp.unstack(x, axis=axis))
self.assertArraysEqual(jnp.moveaxis(y, 0, axis), x)


@parameterized.parameters(
[dtype for dtype in [jnp.bool, jnp.uint8, jnp.uint16, jnp.uint32,
jnp.uint64, jnp.int8, jnp.int16, jnp.int32, jnp.int64,
Expand Down

0 comments on commit f507dc2

Please sign in to comment.