Skip to content

Commit

Permalink
Refactored common upcast for integral-type accumulators
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Apr 23, 2024
1 parent b1cb90c commit a0e30af
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 37 deletions.
8 changes: 2 additions & 6 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3024,14 +3024,10 @@ def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1,
raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")
dtypes.check_user_dtype_supported(dtype, "trace")

a = asarray(a)
a_shape = shape(a)
if dtype is None:
dtype = _dtype(a)
if issubdtype(dtype, integer):
default_int = dtypes.canonicalize_dtype(int)
if iinfo(dtype).bits < iinfo(default_int).bits:
dtype = default_int

a = util.promote_dtypes_integral_default(a)[0]
a = moveaxis(a, (axis1, axis2), (-2, -1))

# Mask out the diagonal and reduce.
Expand Down
14 changes: 7 additions & 7 deletions jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@

from jax import lax
from jax._src import api
from jax._src import core, config
from jax._src import core
from jax._src import dtypes
from jax._src.numpy import ufuncs
from jax._src.numpy.util import (
_broadcast_to, check_arraylike, _complex_elem_type,
promote_dtypes_inexact, promote_dtypes_numeric, _where, implements)
promote_dtypes_inexact, promote_dtypes_numeric, _where, implements,
promote_dtypes_integral_default, )
from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.util import (
Expand Down Expand Up @@ -730,11 +731,10 @@ def cumulative_sum(

axis = _canonicalize_axis(axis, x.ndim)
dtypes.check_user_dtype_supported(dtype)
kind = x.dtype.kind
if (dtype is None and kind in {'i', 'u'}
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
x = x.astype(dtype=dtype or x.dtype)
if dtype is None:
x = promote_dtypes_integral_default(x)[0]
else:
x = x.astype(dtype=dtype)
out = cumsum(x, axis=axis)
if include_initial:
zeros_shape = list(x.shape)
Expand Down
14 changes: 14 additions & 0 deletions jax/_src/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,20 @@ def promote_dtypes_complex(*args: ArrayLike) -> list[Array]:
for x in args]


def promote_dtypes_integral_default(*args: ArrayLike) -> list[Array]:
"""Convenience function to apply default promotion to integral accumulators.
Promotes arguments to their corresponding default integral type, or returns
the arguments unchanged."""
def _promote(x: ArrayLike) -> Array:
x = lax.asarray(x)
kind = x.dtype.kind
if kind in {'i', 'u'}:
return x.astype(dtypes._default_types[kind])
return x
return [_promote(x) for x in args]


def _complex_elem_type(dtype: DTypeLike) -> DType:
"""Returns the float type of the real/imaginary parts of a complex dtype."""
return np.abs(np.zeros((), dtype)).dtype
Expand Down
15 changes: 0 additions & 15 deletions jax/experimental/array_api/_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,18 +212,3 @@ def result_type(*arrays_and_dtypes):
if len(dtypes) == 1:
return dtypes[0]
return functools.reduce(_promote_types, dtypes)


def _promote_to_default_dtype(x):
if x.dtype.kind == 'b':
return x
elif x.dtype.kind == 'i':
return x.astype(jnp.int_)
elif x.dtype.kind == 'u':
return x.astype(jnp.uint)
elif x.dtype.kind == 'f':
return x.astype(jnp.float_)
elif x.dtype.kind == 'c':
return x.astype(jnp.complex_)
else:
raise ValueError(f"Unrecognized {x.dtype=}")
7 changes: 3 additions & 4 deletions jax/experimental/array_api/_linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
# limitations under the License.

import jax
from jax.experimental.array_api._data_type_functions import (
_promote_to_default_dtype,
)
from jax._src.numpy.util import promote_dtypes_integral_default

def cholesky(x, /, *, upper=False):
"""
Expand Down Expand Up @@ -140,7 +138,8 @@ def trace(x, /, *, offset=0, dtype=None):
"""
Returns the sum along the specified diagonals of a matrix (or a stack of matrices) x.
"""
x = _promote_to_default_dtype(x)
if dtype is None:
x = promote_dtypes_integral_default(x)[0]
return jax.numpy.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1)

def vecdot(x1, x2, /, *, axis=-1):
Expand Down
5 changes: 0 additions & 5 deletions jax/experimental/array_api/_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
# limitations under the License.

import jax
from jax.experimental.array_api._data_type_functions import (
_promote_to_default_dtype,
)


def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False):
Expand All @@ -39,7 +36,6 @@ def min(x, /, *, axis=None, keepdims=False):

def prod(x, /, *, axis=None, dtype=None, keepdims=False):
"""Calculates the product of input array x elements."""
x = _promote_to_default_dtype(x)
return jax.numpy.prod(x, axis=axis, dtype=dtype, keepdims=keepdims)


Expand All @@ -50,7 +46,6 @@ def std(x, /, *, axis=None, correction=0.0, keepdims=False):

def sum(x, /, *, axis=None, dtype=None, keepdims=False):
"""Calculates the sum of the input array x."""
x = _promote_to_default_dtype(x)
return jax.numpy.sum(x, axis=axis, dtype=dtype, keepdims=keepdims)


Expand Down
6 changes: 6 additions & 0 deletions jax/experimental/array_api/skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,9 @@ array_api_tests/test_special_cases.py::test_unary

# fft test suite is buggy as of 83f0bcdc
array_api_tests/test_fft.py

# Pending implementation update for proper dtype promotion behavior,
# see https://github.com/data-apis/array-api-tests/issues/234
array_api_tests/test_statistical_functions.py::test_sum
array_api_tests/test_statistical_functions.py::test_prod
array_api_tests/test_linalg.py::test_trace

0 comments on commit a0e30af

Please sign in to comment.