Skip to content

Commit

Permalink
Add support for device kwarg in astype to match Array API
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Mar 22, 2024
1 parent d7e5dde commit 6856b94
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 10 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.26

* New Features
* {func}`jax.numpy.astype` supports new `device` keyword argument.

* Deprecations & Removals
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
Expand All @@ -30,6 +33,12 @@ Remember to align the itemized text with the first line of an item within a list
`jax.interpreters.ad.source_info_util` have now been removed. Use `jax.config`
and `jax.extend.source_info_util` instead.

* Bug fixes
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
Previously, no copy would be made when the output array would have the same
dtype as the input array. This may result in some increased memory usage.
To prevent copying when possible, set `copy=False`.

## jaxlib 0.4.26

* Changes
Expand Down
6 changes: 4 additions & 2 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@
import numpy as np
import jax
from jax import lax
from jax.sharding import Sharding
from jax._src import core
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
from jax._src.array import ArrayImpl
from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_client as xc
from jax._src.numpy import lax_numpy
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
Expand All @@ -55,15 +57,15 @@
# functions, which can themselves handle instances from any of these classes.


def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array:
def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = True, device: xc.Device | Sharding | None = None) -> Array:
"""Copy the array and cast to a specified dtype.
This is implemented via :func:`jax.lax.convert_element_type`, which may
have slightly different behavior than :meth:`numpy.ndarray.astype` in
some cases. In particular, the details of float-to-int and int-to-float
casts are implementation dependent.
"""
return lax_numpy.astype(arr, dtype)
return lax_numpy.astype(arr, dtype, copy=copy, device=device)


def _nbytes(arr: ArrayLike) -> int:
Expand Down
35 changes: 30 additions & 5 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import opt_einsum

import jax
from jax import jit
from jax import jit, device_put
from jax import errors
from jax import lax
from jax.sharding import Sharding, SingleDeviceSharding
Expand Down Expand Up @@ -2209,19 +2209,44 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
else:
return x


@util.implements(getattr(np, "astype", None), lax_description="""
This is implemented via :func:`jax.lax.convert_element_type`, which may
have slightly different behavior than :func:`numpy.astype` in some cases.
In particular, the details of float-to-int and int-to-float casts are
implementation dependent.
""")
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
del copy # unused in JAX
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True, device: xc.Device | Sharding | None = None) -> Array:
if dtype is None:
dtype = dtypes.canonicalize_dtype(float_)
dtypes.check_user_dtype_supported(dtype, "astype")
return lax.convert_element_type(x, dtype)
src_dtype = x.dtype if hasattr(x, "dtype") else dtypes.dtype(x)
if (
src_dtype is not None
and dtypes.isdtype(src_dtype, "complex floating")
and dtypes.isdtype(dtype, ("integral", "real floating"))
):
warnings.warn(
"Casting from complex to non-complex dtypes will soon raise a ValueError. "
"Please first use jnp.real or jnp.imag to take the real/imaginary "
"component of your input.",
DeprecationWarning, stacklevel=2
)
src_devices = (
x.devices() if hasattr(x, "devices")
and not isinstance(x, core.Tracer) else None
)
arr = x
if device is not None and src_devices != {device}:
arr = device_put(x, device)
elif copy:
arr = _array_copy(x)

# We offer a more specific warning than the usual ComplexWarning so we prefer
# to issue our warning.
with warnings.catch_warnings():
warnings.simplefilter("ignore", ComplexWarning)
return lax.convert_element_type(arr, dtype)



@util.implements(np.asarray, lax_description=_ARRAY_DOC)
Expand Down
20 changes: 18 additions & 2 deletions jax/experimental/array_api/_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import builtins
import functools
from typing import NamedTuple
import jax
import jax.numpy as jnp


from jax._src.lib import xla_client as xc
from jax._src.sharding import Sharding
from jax._src import dtypes as _dtypes
from jax.experimental.array_api._dtypes import (
bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64,
float32, float64, complex64, complex128
Expand Down Expand Up @@ -124,8 +129,19 @@ def _promote_types(t1, t2):
raise ValueError("No promotion path for {t1} & {t2}")


def astype(x, dtype, /, *, copy=True):
return jnp.array(x, dtype=dtype, copy=copy)
def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Sharding | None = None):
src_dtype = x.dtype if hasattr(x, "dtype") else _dtypes.dtype(x)
if (
src_dtype is not None
and _dtypes.isdtype(src_dtype, "complex floating")
and _dtypes.isdtype(dtype, ("integral", "real floating"))
):
raise ValueError(
"Casting from complex to non-complex dtypes is not permitted. Please "
"first use jnp.real or jnp.imag to take the real/imaginary component of "
"your input."
)
return jnp.astype(x, dtype, copy=copy, device=device)


def can_cast(from_, to, /):
Expand Down
4 changes: 3 additions & 1 deletion jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ from jax._src import dtypes as _dtypes
from jax._src.lax.lax import PrecisionLike
from jax._src.lax.slicing import GatherScatterMode
from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass
from jax._src.sharding import Sharding
from jax._src.lib import xla_client as xc
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DimSize, DuckTypedArray, Shape
from jax.numpy import fft as fft, linalg as linalg
from jax.sharding import Sharding as _Sharding
Expand Down Expand Up @@ -112,7 +114,7 @@ def asarray(
) -> Array: ...
def asin(x: ArrayLike, /) -> Array: ...
def asinh(x: ArrayLike, /) -> Array: ...
def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ...) -> Array: ...
def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ..., device: xc.Device | Sharding | None = ...) -> Array: ...
def atan(x: ArrayLike, /) -> Array: ...
def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def atanh(x: ArrayLike, /) -> Array: ...
Expand Down
35 changes: 35 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3789,6 +3789,41 @@ def testAstype(self, from_dtype, to_dtype, use_method):
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)

@jtu.sample_product(
change_dtype=[True, False],
copy=[True, False],
change_device=[True, False],
)
def testAstypeCopy(self, change_dtype, copy, change_device):
if jax.device_count() == 1 and change_device:
raise unittest.SkipTest(
"Testing device transfer requires at least two available devices."
)

dtype = 'float32' if change_dtype else 'int32'
device = jax.devices()[-1] if change_device else None
expect_copy = change_dtype or copy or change_device
x = jnp.arange(5, dtype='int32')
y = x.astype(dtype, copy=copy, device=device)

assert y.dtype == dtype
if change_device:
assert y.devices() == {device}
else:
y.delete()
get_val = lambda: np.array(x)
err_msg = "Array has been deleted"
if expect_copy:
get_val()
else:
jtu.check_raises(get_val, RuntimeError, err_msg)

def testAstypeComplexDowncast(self):
x = jnp.array(2.0+1.5j, dtype='complex64')
msg = "Casting from complex to non-complex dtypes will soon raise "
with self.assertWarns(DeprecationWarning, msg=msg):
x.astype('float32')

def testAstypeInt4(self):
# Test converting from int4 to int8
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)
Expand Down

0 comments on commit 6856b94

Please sign in to comment.