Skip to content

Commit

Permalink
Add new cumulative_sum function to numpy and array_api
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Apr 16, 2024
1 parent adbb11f commit f111aa3
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 3 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.27

* New Functionality
* Added {func}`jax.numpy.unstack`, following the addition of this function in
the array API 2023 standard, soon to be adopted by NumPy.
* Added {func}`jax.numpy.unstack` and {func}`jax.numpy.cumulative_sum`,
following their addition in the array API 2023 standard, soon to be
adopted by NumPy.

* Changes
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
Expand Down
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ namespace; they are listed below.
csingle
cumprod
cumsum
cumulative_sum
deg2rad
degrees
delete
Expand Down
36 changes: 35 additions & 1 deletion jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from jax import lax
from jax._src import api
from jax._src import core
from jax._src import core, config
from jax._src import dtypes
from jax._src.numpy import ufuncs
from jax._src.numpy.util import (
Expand Down Expand Up @@ -708,6 +708,40 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
fill_nan=True, fill_value=1)

@implements(getattr(np, 'cumulative_sum', None))
def cumulative_sum(
x: ArrayLike, /, *, axis: int | None = None,
dtype: DTypeLike | None = None,
include_initial: bool = False) -> Array:
check_arraylike("cumulative_sum", x)
x = lax_internal.asarray(x)
if x.ndim == 0:
raise ValueError(
"The input must be non-scalar to take a cumulative sum, however a "
"scalar value or scalar array was given."
)
if axis is None and x.ndim > 1:
raise ValueError(
f"The input array has rank {x.ndim}, however axis was not set to an "
"explicit value. The axis argument is only optional for one-dimensional "
"arrays.")
axis = axis or 0
axis = _canonicalize_axis(axis, x.ndim)
dtypes.check_user_dtype_supported(dtype)
kind = x.dtype.kind
if (dtype is None and kind in {'i', 'u'}
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
x = x.astype(dtype=dtype or x.dtype)
out = cumsum(x, axis=axis)
if include_initial:
zeros_shape = list(x.shape)
zeros_shape[axis] = 1
out = lax_internal.concatenate(
[lax_internal.full(zeros_shape, 0, dtype=out.dtype), out],
dimension=axis)
return out

# Quantiles
@implements(np.quantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@
)

from jax.experimental.array_api._statistical_functions import (
cumulative_sum as cumulative_sum,
max as max,
mean as mean,
min as min,
Expand Down
4 changes: 4 additions & 0 deletions jax/experimental/array_api/_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
)


def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False):
"""Calculates the cumulative sum of elements in the input array x."""
return jax.numpy.cumulative_sum(x, axis=axis, dtype=dtype, include_initial=include_initial)

def max(x, /, *, axis=None, keepdims=False):
"""Calculates the maximum value of the input array x."""
return jax.numpy.max(x, axis=axis, keepdims=keepdims)
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@
count_nonzero as count_nonzero,
cumsum as cumsum,
cumprod as cumprod,
cumulative_sum as cumulative_sum,
max as max,
mean as mean,
median as median,
Expand Down
3 changes: 3 additions & 0 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ def cumprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
cumproduct = cumprod
def cumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
out: None = ...) -> Array: ...
def cumulative_sum(x: ArrayLike, /, *, axis: int | None = ...,
dtype: DTypeLike | None = ...,
include_initial: bool = ...) -> Array: ...

def deg2rad(x: ArrayLike, /) -> Array: ...
degrees = rad2deg
Expand Down
1 change: 1 addition & 0 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
'copysign',
'cos',
'cosh',
'cumulative_sum',
'divide',
'e',
'empty',
Expand Down
64 changes: 64 additions & 0 deletions tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,5 +770,69 @@ def test_f16_mean(self, dtype):
self.assertAllClose(expected, actual, atol=0)


@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in all_shapes
for axis in list(
range(-len(shape), len(shape))
) + ([None] if len(shape) == 1 else [])],
dtype=all_dtypes,
out_dtype=all_dtypes,
include_initial=[False, True],
)
@jtu.ignore_warning(category=NumpyComplexWarning)
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testCumulativeSum(self, shape, axis, dtype, out_dtype, include_initial):
rng = jtu.rand_some_zero(self.rng())

def np_mock_op(x, axis=None, dtype=None, include_initial=False):
kind = x.dtype.kind
if (dtype is None and kind in {'i', 'u'}
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
axis = axis or 0
x = x.astype(dtype=dtype or x.dtype)
out = jnp.cumsum(x, axis=axis)
if include_initial:
zeros_shape = list(x.shape)
zeros_shape[axis] = 1
out = jnp.concat([jnp.zeros(zeros_shape, dtype=out.dtype), out], axis=axis)
return out


# We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as
# input because we rely on JAX-specific casting behavior
args_maker = lambda: [jnp.array(rng(shape, dtype))]
np_op = getattr(np, "cumulative_sum", np_mock_op)
kwargs = dict(axis=axis, include_initial=include_initial)

np_fun = lambda x: np_op(x, dtype=out_dtype, **kwargs)
jnp_fun = lambda x: jnp.cumulative_sum(x, dtype=out_dtype, **kwargs)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

np_fun = lambda x: np_op(x, dtype=None, **kwargs)
jnp_fun = lambda x: jnp.cumulative_sum(x, dtype=None, **kwargs)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)


@jtu.sample_product(
shape=filter(lambda x: len(x) != 1, all_shapes), dtype=all_dtypes,
include_initial=[False, True])
def testCumulativeSumErrors(self, shape, dtype, include_initial):
rng = jtu.rand_some_zero(self.rng())
x = rng(shape, dtype)
rank = jnp.asarray(x).ndim
if rank == 0:
msg = r"The input must be non-scalar to take"
with self.assertRaisesRegex(ValueError, msg):
jnp.cumulative_sum(x, include_initial=include_initial)
elif rank > 1:
msg = r"The input array has rank \d*, however"
with self.assertRaisesRegex(ValueError, msg):
jnp.cumulative_sum(x, include_initial=include_initial)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit f111aa3

Please sign in to comment.