Skip to content

Commit

Permalink
Merge pull request #22103 from rajasekharporeddy:testbranch2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 647060866
  • Loading branch information
jax authors committed Jun 26, 2024
2 parents 0ed2254 + 2702f21 commit 66287cd
Showing 1 changed file with 67 additions and 10 deletions.
77 changes: 67 additions & 10 deletions jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,15 +204,6 @@ def force(x):
force, x, "The axis argument must be known statically.")


# TODO(jakevdp) change promote_integers default to False
_PROMOTE_INTEGERS_DOC = """
promote_integers : bool, default=True
If True, then integer inputs will be promoted to the widest available integer
dtype, following numpy's behavior. If False, the result will have the same dtype
as the input. ``promote_integers`` is ignored if ``dtype`` is specified.
"""


@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True)
def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
Expand Down Expand Up @@ -309,11 +300,77 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, promote_integers=promote_integers)

@implements(np.prod, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)

def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None,
promote_integers: bool = True) -> Array:
r"""Return product of the array elements over a given axis.
JAX implementation of :func:`numpy.prod`.
Args:
a: Input array.
axis: int or array, default=None. Axis along which the product to be computed.
If None, the product is computed along all the axes.
dtype: The type of the output array. Default=None.
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
initial: int or array, Default=None. Initial value for the product.
where: int or array, default=None. The elements to be used in the product.
Array should be broadcast compatible to the input.
promote_integers : bool, default=True. If True, then integer inputs will be
promoted to the widest available integer dtype, following numpy's behavior.
If False, the result will have the same dtype as the input.
``promote_integers`` is ignored if ``dtype`` is specified.
out: Unused by JAX.
Returns:
An array of the product along the given axis.
See also:
- :func:`jax.numpy.sum`: Compute the sum of array elements over a given axis.
- :func:`jax.numpy.max`: Compute the maximum of array elements over given axis.
- :func:`jax.numpy.min`: Compute the minimum of array elements over given axis.
Examples:
By default, ``jnp.prod`` computes along all the axes.
>>> x = jnp.array([[1, 3, 4, 2],
... [5, 2, 1, 3],
... [2, 1, 3, 1]])
>>> jnp.prod(x)
Array(4320, dtype=int32)
If ``axis=1``, product is computed along axis 1.
>>> jnp.prod(x, axis=1)
Array([24, 30, 6], dtype=int32)
If ``keepdims=True``, ``ndim`` of the output is equal to that of the input.
>>> jnp.prod(x, axis=1, keepdims=True)
Array([[24],
[30],
[ 6]], dtype=int32)
To include only specific elements in the sum, you can use a``where``.
>>> where=jnp.array([[1, 0, 1, 0],
... [0, 0, 1, 1],
... [1, 1, 1, 0]], dtype=bool)
>>> jnp.prod(x, axis=1, keepdims=True, where=where)
Array([[4],
[3],
[6]], dtype=int32)
>>> where = jnp.array([[False],
... [False],
... [False]])
>>> jnp.prod(x, axis=1, keepdims=True, where=where)
Array([[1],
[1],
[1]], dtype=int32)
"""
return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype,
out=out, keepdims=keepdims, initial=initial, where=where,
promote_integers=promote_integers)
Expand Down

0 comments on commit 66287cd

Please sign in to comment.