Skip to content

Commit

Permalink
Add support for device kwarg in astype to match Array API
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Mar 20, 2024
1 parent 281f5ac commit 2395e76
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 6 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.26

* New Features
* {func}`jax.numpy.astype` supports new `device` keyword argument.

* Deprecations & Removals
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
Expand All @@ -28,6 +31,12 @@ Remember to align the itemized text with the first line of an item within a list
`jax.interpreters.ad.source_info_util` have now been removed. Use `jax.config`
and `jax.extend.source_info_util` instead.

* Bug fixes
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
Previously, no copy would be made when the output array would have the same
dtype as the input array. This may result in some increased memory usage.
To prevent copying when possible, set `copy=False`.

## jaxlib 0.4.26

## jax 0.4.25 (Feb 26, 2024)
Expand Down
6 changes: 4 additions & 2 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@
import numpy as np
import jax
from jax import lax
from jax.sharding import Sharding
from jax._src import core
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
from jax._src.array import ArrayImpl
from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_client as xc
from jax._src.numpy import lax_numpy
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
Expand All @@ -55,15 +57,15 @@
# functions, which can themselves handle instances from any of these classes.


def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array:
def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = True, device: xc.Device | Sharding | None = None) -> Array:
"""Copy the array and cast to a specified dtype.
This is implemented via :func:`jax.lax.convert_element_type`, which may
have slightly different behavior than :meth:`numpy.ndarray.astype` in
some cases. In particular, the details of float-to-int and int-to-float
casts are implementation dependent.
"""
return lax_numpy.astype(arr, dtype)
return lax_numpy.astype(arr, dtype, copy=copy, device=device)


def _nbytes(arr: ArrayLike) -> int:
Expand Down
24 changes: 20 additions & 4 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2216,13 +2216,29 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
implementation dependent.
""")
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True, device: xc.Device | Sharding | None = None) -> Array:
src_devices = x.devices() if hasattr(x, "devices") and not isinstance(x, core.Tracer) else None
if device is not None and src_devices != {device}:
return device_put(x, device)
arr = _array_copy(x) if copy else x
if dtype is None:
dtype = dtypes.canonicalize_dtype(float_)
dtypes.check_user_dtype_supported(dtype, "astype")
src_dtype = x.dtype if hasattr(x, "dtype") else dtypes.dtype(x)
if (
src_dtype is not None
and dtypes.isdtype(src_dtype, "complex floating")
and dtypes.isdtype(dtype, ("integral", "real floating"))
):
raise ValueError(
"Casting from complex to non-complex dtypes is not permitted. Please "
"first use jnp.real or jnp.imag to take the real/imaginary component of "
"your input."
)
src_devices = (
x.devices() if hasattr(x, "devices")
and not isinstance(x, core.Tracer) else None
)
arr = x
if device is not None and src_devices != {device}:
arr = device_put(x, device)
elif copy:
arr = _array_copy(x)
return lax.convert_element_type(arr, dtype)


Expand Down
35 changes: 35 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3789,6 +3789,41 @@ def testAstype(self, from_dtype, to_dtype, use_method):
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)

@jtu.sample_product(
change_dtype=[True, False],
copy=[True, False],
change_device=[True, False],
)
def testAstypeCopy(self, change_dtype, copy, change_device):
if jax.device_count() == 1 and change_device:
raise unittest.SkipTest(
"Testing device transfer requires at least two available devices."
)

dtype = 'float32' if change_dtype else 'int32'
device = jax.devices()[-1] if change_device else None
expect_copy = change_dtype or copy or change_device
x = jnp.arange(5, dtype='int32')
y = x.astype(dtype, copy=copy, device=device)

assert y.dtype == dtype
if change_device:
assert y.devices() == {device}
else:
y.delete()
get_val = lambda: np.array(x)
err_msg = "Array has been deleted"
if expect_copy:
get_val()
else:
jtu.check_raises(get_val, RuntimeError, err_msg)

def testAstypeComplexDowncast(self):
x = jnp.array(2.0+1.5j, dtype='complex64')
complex_downcast = lambda: x.astype('float32')
err_msg = "Casting from complex to non-complex "
jtu.check_raises(complex_downcast, ValueError, err_msg)

def testAstypeInt4(self):
# Test converting from int4 to int8
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)
Expand Down

0 comments on commit 2395e76

Please sign in to comment.