Skip to content

Commit

Permalink
Add new cumulative_sum function to numpy and array_api namespaces
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Apr 15, 2024
1 parent 2c85ca6 commit f5fa5c1
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 0 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.27

* New Functionality
* Added {func}`jax.numpy.cumulative_sum`, following the addition of this
function in the array API 2023 standard, soon to be adopted by NumPy.

* Changes
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
Expand Down
31 changes: 31 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5567,3 +5567,34 @@ def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
else:
raise ValueError(f"mode should be one of 'wrap' or 'clip'; got {mode=}")
return arr.at[unravel_index(ind_arr, arr.shape)].set(v_arr, mode=scatter_mode)


@util.implements(getattr(np, 'cumulative_sum', None))
def cumulative_sum(
x: Array, /, *, axis: int | None = None,
dtype: DTypeLike | None = None,
include_initial: bool = False) -> Array:
if isscalar(x) or x.ndim == 0:
raise ValueError(
"The input must be non-scalar to take a cumulative sum, however a "
"scalar value or scalar array was given."
)
if axis is None and x.ndim > 1:
raise ValueError(
f"The input array has rank {x.ndim}, however axis was not set to an "
"explicit value. The axis argument is only optional for one-dimensional "
"arrays.")
axis = axis or 0
util.check_arraylike("cumulative_sum", x)
dtypes.check_user_dtype_supported(dtype)
kind = x.dtype.kind
if (dtype is None and kind in {'i', 'u'}
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
dtype = dtypes.dtype(dtypes._default_types[kind])

out = reductions.cumsum(x, axis=axis, dtype=dtype)
zeros_shape = list(x.shape)
zeros_shape[axis] = 1
if include_initial:
out = concat([zeros(zeros_shape, dtype=out.dtype), out], axis=axis)
return out
1 change: 1 addition & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@
)

from jax.experimental.array_api._statistical_functions import (
cumulative_sum as cumulative_sum,
max as max,
mean as mean,
min as min,
Expand Down
4 changes: 4 additions & 0 deletions jax/experimental/array_api/_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
)


def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False):
"""Calculates the cumulative sum of elements in the input array x."""
return jax.numpy.cumulative_sum(x, axis=axis, dtype=dtype, include_initial=include_initial)

def max(x, /, *, axis=None, keepdims=False):
"""Calculates the maximum value of the input array x."""
return jax.numpy.max(x, axis=axis, keepdims=keepdims)
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
cov as cov,
cross as cross,
csingle as csingle,
cumulative_sum as cumulative_sum,
delete as delete,
diag as diag,
diagflat as diagflat,
Expand Down
1 change: 1 addition & 0 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
'conj',
'cos',
'cosh',
'cumulative_sum',
'divide',
'e',
'empty',
Expand Down
53 changes: 53 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,59 @@ def np_fun(x):
atol={dtypes.bfloat16: 1e-1, np.float16: 1e-2})
self._CompileAndCheck(jnp_fun, args_maker, atol={dtypes.bfloat16: 1e-1})


@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in all_shapes
for axis in list(range(-len(shape), len(shape))) + [None] if len(shape) == 1],
dtype=all_dtypes,
out_dtype=all_dtypes + [None],
include_initial=[False, True],
)
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testCumulativeSum(self, shape, axis, dtype, out_dtype, include_initial):
rng = jtu.rand_some_zero(self.rng())
x = rng(shape, dtype)
out = jnp.cumulative_sum(x, dtype=out_dtype, include_initial=include_initial)

target_dtype = out_dtype or x.dtype
kind = x.dtype.kind
if (out_dtype is None and kind in {'i', 'u'}
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
target_dtype = dtypes.dtype(dtypes._default_types[kind])
assert out.dtype == target_dtype

_axis = axis or 0
target_shape = list(x.shape)
if include_initial:
target_shape[_axis] += 1
assert out.shape == tuple(target_shape)

target = jnp.cumsum(x, axis=_axis, dtype=out.dtype)
if include_initial:
zeros_shape = target_shape
zeros_shape[_axis] = 1
target = jnp.concat([jnp.zeros(target_shape, dtype=out.dtype), target])
self.assertArraysEqual(out, target)


@jtu.sample_product(
shape=all_shapes, dtype=all_dtypes,
include_initial=[False, True])
def testCumulativeSumErrors(self, shape, dtype, include_initial):
rng = jtu.rand_some_zero(self.rng())
x = rng(shape, dtype)
if jnp.isscalar(x) or x.ndim == 0:
msg = r"The input must be non-scalar to take"
with self.assertRaisesRegex(ValueError, msg):
jnp.cumulative_sum(x, include_initial=include_initial)
elif x.ndim > 1:
msg = r"The input array has rank \d*, however"
with self.assertRaisesRegex(ValueError, msg):
jnp.cumulative_sum(x, include_initial=include_initial)



@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in all_shapes
Expand Down

0 comments on commit f5fa5c1

Please sign in to comment.