Refactored common upcast for integral-type accumulators
Micky774 committed May 2, 2024
1 parent 187b2ac commit cd288ee
Showing 3 changed files with 33 additions and 45 deletions.
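
For context, the rule being consolidated here follows jnp.sum's existing promote_integers semantics (widen to the default int/uint width) rather than NumPy's always-promote-to-64-bit behavior. A minimal behavior sketch, not part of the diff; expected dtypes assume the default 32-bit int configuration (x64 disabled):

# Behavior sketch: narrow integer accumulators widen to the default width.
import jax.numpy as jnp

print(jnp.sum(jnp.ones(4, dtype=jnp.int8)).dtype)    # int32  -- promoted to the default int
print(jnp.sum(jnp.ones(4, dtype=jnp.uint8)).dtype)   # uint32 -- promoted to the default uint
print(jnp.sum(jnp.ones(4, dtype=jnp.int32)).dtype)   # int32  -- already at the default width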
7 changes: 0 additions & 7 deletions jax/_src/numpy/lax_numpy.py
@@ -3364,13 +3364,6 @@ def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1,
dtypes.check_user_dtype_supported(dtype, "trace")

a_shape = shape(a)
if dtype is None:
dtype = _dtype(a)
if issubdtype(dtype, integer):
default_int = dtypes.canonicalize_dtype(int)
if iinfo(dtype).bits < iinfo(default_int).bits:
dtype = default_int

a = moveaxis(a, (axis1, axis2), (-2, -1))

# Mask out the diagonal and reduce.
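The inline upcast removed from trace above appears redundant rather than a behavior change: the masked-diagonal reduction that follows goes through the shared sum path, whose promote_integers default already applies the same widening. A hedged check, assuming trace lowers to jnp.sum as the surrounding comment suggests:

# Hedged check (assumption: trace reduces via jnp.sum with promote_integers=True).
import jax.numpy as jnp

x = jnp.eye(3, dtype=jnp.int8)
print(jnp.trace(x).dtype)  # expected int32: promotion now happens in the reduction itself
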
56 changes: 33 additions & 23 deletions jax/_src/numpy/reductions.py
@@ -26,7 +26,7 @@

from jax import lax
from jax._src import api
from jax._src import core, config
from jax._src import core
from jax._src import dtypes
from jax._src.numpy import ufuncs
from jax._src.numpy.util import (
@@ -65,6 +65,20 @@ def _upcast_f16(dtype: DTypeLike) -> DType:
return np.dtype('float32')
return np.dtype(dtype)

def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike:
# Note: NumPy always promotes to 64-bit; jax instead promotes to the
# default dtype as defined by dtypes.int_ or dtypes.uint.
if dtypes.issubdtype(dtype, np.bool_):
return dtypes.int_
elif dtypes.issubdtype(dtype, np.unsignedinteger):
if np.iinfo(dtype).bits < np.iinfo(dtypes.uint).bits:
return dtypes.uint
elif dtypes.issubdtype(dtype, np.integer):
if np.iinfo(dtype).bits < np.iinfo(dtypes.int_).bits:
return dtypes.int_
return dtype


ReductionOp = Callable[[Any, Any], Any]

def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: ArrayLike,
@@ -103,16 +117,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
result_dtype = dtype or dtypes.dtype(a)

if dtype is None and promote_integers:
# Note: NumPy always promotes to 64-bit; jax instead promotes to the
# default dtype as defined by dtypes.int_ or dtypes.uint.
if dtypes.issubdtype(result_dtype, np.bool_):
result_dtype = dtypes.int_
elif dtypes.issubdtype(result_dtype, np.unsignedinteger):
if np.iinfo(result_dtype).bits < np.iinfo(dtypes.uint).bits:
result_dtype = dtypes.uint
elif dtypes.issubdtype(result_dtype, np.integer):
if np.iinfo(result_dtype).bits < np.iinfo(dtypes.int_).bits:
result_dtype = dtypes.int_
result_dtype = _promote_integer_dtype(result_dtype)

result_dtype = dtypes.canonicalize_dtype(result_dtype)

@@ -653,7 +658,8 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:

class CumulativeReduction(Protocol):
def __call__(self, a: ArrayLike, axis: Axis = None,
dtype: DTypeLike | None = None, out: None = None) -> Array: ...
dtype: DTypeLike | None = None, out: None = None,
promote_integers: bool = False) -> Array: ...


# TODO(jakevdp): should we change these semantics to match those of numpy?
@@ -667,12 +673,17 @@ def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array
@implements(np_reduction, skip_params=['out'],
lax_description=CUML_REDUCTION_LAX_DESCRIPTION)
def cumulative_reduction(a: ArrayLike, axis: Axis = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out)
dtype: DTypeLike | None = None, out: None = None,
promote_integers: bool = False) -> Array:
return _cumulative_reduction(
a, _ensure_optional_axes(axis), dtype,
out, promote_integers=promote_integers
)

@partial(api.jit, static_argnames=('axis', 'dtype'))
@partial(api.jit, static_argnames=('axis', 'dtype', 'promote_integers'))
def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
dtype: DTypeLike | None = None, out: None = None,
promote_integers: bool = False) -> Array:
check_arraylike(np_reduction.__name__, a)
if out is not None:
raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} "
@@ -691,11 +702,15 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
if fill_nan:
a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)

if not dtype and dtypes.dtype(a) == np.bool_:
result_type = dtype or dtypes.dtype(a)
if result_type == np.bool_:
dtype = dtypes.canonicalize_dtype(dtypes.int_)
elif dtype is None and promote_integers:
dtype = _promote_integer_dtype(result_type)
if dtype:
a = lax.convert_element_type(a, dtype)


return reduction(a, axis)

return cumulative_reduction
@@ -730,12 +745,7 @@ def cumulative_sum(

axis = _canonicalize_axis(axis, x.ndim)
dtypes.check_user_dtype_supported(dtype)
kind = x.dtype.kind
if (dtype is None and kind in {'i', 'u'}
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
x = x.astype(dtype=dtype or x.dtype)
out = cumsum(x, axis=axis)
out = cumsum(x, axis=axis, dtype=dtype, promote_integers=True)
if include_initial:
zeros_shape = list(x.shape)
zeros_shape[axis] = 1
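A hedged usage sketch of the new promote_integers plumbing in this file: the public cumsum/cumprod wrappers keep promote_integers=False by default, so their result dtypes are unchanged, while cumulative_sum opts in and widens narrow integer inputs through _promote_integer_dtype. Expected dtypes assume the default 32-bit configuration, and jnp.cumulative_sum assumes a jax version that exports it:

# Usage sketch, not part of the diff.
import jax.numpy as jnp

x = jnp.arange(4, dtype=jnp.int8)
print(jnp.cumsum(x).dtype)           # int8   -- promote_integers defaults to False here
print(jnp.cumulative_sum(x).dtype)   # int32  -- opts into the shared integer upcast
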
15 changes: 0 additions & 15 deletions jax/experimental/array_api/_data_type_functions.py
@@ -76,18 +76,3 @@ def finfo(type, /) -> FInfo:
smallest_normal=float(info.smallest_normal),
dtype=jnp.dtype(type)
)

# TODO(micky774): Update utility to only promote integral types
def _promote_to_default_dtype(x):
if x.dtype.kind == 'b':
return x
elif x.dtype.kind == 'i':
return x.astype(jnp.int_)
elif x.dtype.kind == 'u':
return x.astype(jnp.uint)
elif x.dtype.kind == 'f':
return x.astype(jnp.float_)
elif x.dtype.kind == 'c':
return x.astype(jnp.complex_)
else:
raise ValueError(f"Unrecognized {x.dtype=}")
