From 7c7c3b897063955c714bc9c0b364ef4260679ade Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Fri, 24 May 2024 21:32:32 +0000
Subject: [PATCH] Add support for device kwarg in astype, and add matching
 utility func

---
 CHANGELOG.md                           |  3 ++
 jax/_src/dlpack.py                     | 19 ++----------
 jax/_src/numpy/lax_numpy.py            | 14 ++++-----
 jax/_src/numpy/util.py                 | 25 +++++++++++++++-
 jax/experimental/array_api/__init__.py |  6 ++++
 tests/array_interoperability_test.py   |  7 +++--
 tests/lax_numpy_test.py                | 40 ++++++++++++++++++++------
 7 files changed, 77 insertions(+), 37 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d9f937d7ee17..9a0e727ef6c5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -37,6 +37,9 @@ Remember to align the itemized text with the first line of an item within a list
 
 ## jax 0.4.28 (May 9, 2024)
 
+* New Functionality
+  * {func}`jax.numpy.astype` supports a new `device` keyword argument.
+
 * Bug fixes
   * Reverted a change to `make_jaxpr` that was breaking Equinox (#21116).
 
diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py
index 386123ae61f0..e4c72c7690e3 100644
--- a/jax/_src/dlpack.py
+++ b/jax/_src/dlpack.py
@@ -24,6 +24,7 @@
 from jax._src.lib import xla_client
 from jax._src.typing import Array, DLDeviceType
 from jax._src.sharding import Sharding
+from jax._src.numpy.util import _place_array
 
 DLPACK_VERSION = (0, 8)
 MIN_DLPACK_VERSION = (0, 5)
@@ -148,19 +149,6 @@ def to_dlpack(x: Array, stream: int | Any | None = None,
       f"version ({max_version}) was requested."
     )
 
-def _place_array(_arr, device, dlpack_device, copy):
-  if device and dlpack_device != device:
-    if copy is not None and not copy:
-      raise ValueError(
-          f"Specified {device=} which requires a copy since the source device "
-          f"is {repr(dlpack_device)}, however copy=False. Set copy=True or "
-          "copy=None to perform the requested operation."
-      )
-    else:
-      return device_put(_arr, device)
-  if copy:
-    return jnp.array(_arr, copy=True)
-  return _arr
 
 def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None,
                         copy: bool | None = None):
@@ -194,8 +182,7 @@ def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None,
     _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
         dlpack, cpu_backend, gpu_backend))
 
-  dlpack_device, = _arr.devices()
-  return _place_array(_arr, device, dlpack_device, copy)
+  return _place_array(_arr, device, copy)
 
 def _from_dlpack(external_array, device: xla_client.Device | None = None,
                  copy: bool | None = None):
@@ -226,7 +213,7 @@ def _from_dlpack(external_array, device: xla_client.Device | None = None,
   _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
       dlpack, dlpack_device, stream))
 
-  return _place_array(_arr, device, dlpack_device, copy)
+  return _place_array(_arr, device, copy)
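
[Usage note, not part of the patch] The dlpack changes above delegate device
placement and copy handling to the shared `_place_array` helper introduced in
jax/_src/numpy/util.py below. From the caller's side, the contract looks like
the following sketch (device and copy values are illustrative; note that, per
the helper's TODO, passing a device together with copy=False currently raises
even when no actual transfer would be needed):

    import jax
    import jax.numpy as jnp
    import numpy as np

    x_np = np.arange(4.0, dtype=np.float32)
    device = jax.devices()[0]

    # copy=None (the default) lets JAX copy only when placement requires it.
    y = jnp.from_dlpack(x_np, device=device, copy=None)

    # copy=False forbids copying; combined with an explicit device, this
    # raises the ValueError enforced by _place_array.
    try:
        jnp.from_dlpack(x_np, device=device, copy=False)
    except ValueError as e:
        print(e)  # Specified device=... which requires a copy ...
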
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 74b8b64cb018..9fd8a8eb94ee 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -2853,18 +2853,14 @@ def astype(x: ArrayLike, dtype: DTypeLike | None,
   # to issue our warning.
   with warnings.catch_warnings():
     warnings.simplefilter("ignore", ComplexWarning)
-    return _place_array(
+    return util._place_array(
       lax.convert_element_type(x_arr, dtype),
-      device=device, copy=copy,
+      device=device,
+      # We translate between the array API semantics of copy in _place_array
+      # and the NumPy semantics of copy in astype.
+      copy=True if copy else None,
     )
 
-def _place_array(x, device=None, copy=None):
-  # TODO(micky774): Implement in future PRs as we formalize device placement
-  # semantics
-  if copy:
-    return _array_copy(x)
-  return x
-
 
 @util.implements(np.asarray, lax_description=_ARRAY_DOC)
 def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py
index fb3b7e4e9dc9..62d82cce4b4d 100644
--- a/jax/_src/numpy/util.py
+++ b/jax/_src/numpy/util.py
@@ -18,14 +18,16 @@
 import re
 import textwrap
 from typing import Any, Callable, NamedTuple, TypeVar
-
 import warnings
 
+from jax.sharding import Sharding
+
 from jax._src import api
 from jax._src import config
 from jax._src import core
 from jax._src import dtypes
 from jax._src.lax import lax
+from jax._src.lib import xla_client as xc
 from jax._src.util import safe_zip, safe_map
 from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
 
@@ -117,6 +119,27 @@ def _parse_extra_params(extra_params: str) -> dict[str, str]:
   return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters}
 
 
+def _place_array(x: Array, device: xc.Device | Sharding | None = None, copy: bool | None = None) -> Array:
+  """Helper utility for copying an array, or placing it on a device or sharding.
+
+  This utility uses `jax.device_put` for device placement.
+  """
+  out = x
+  if device is not None:
+    # TODO(micky774): Add a check to avoid raising an error if no actual
+    # device transfer is necessary.
+    if copy is not None and not copy:
+      raise ValueError(
+        f"Specified {device=} which requires a copy, however copy=False. Set "
+        "copy=True or copy=None to perform the requested operation."
+      )
+    out = api.device_put(out, device)
+
+  # TODO(micky774): Avoid the copy if the data has already been copied via
+  # device transfer.
+  return lax._array_copy(out) if copy else out
+
+
 def implements(
     original_fun: Callable[..., Any] | None,
     update_doc: bool = True,
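
[Usage note, not part of the patch] With the helper in place, astype composes
a dtype cast with device placement, and astype's NumPy-style copy flag (a
plain bool) is mapped onto the helper's array-API-style flag (True / False /
None, where None means "copy only if needed"). A small sketch of the
resulting user-facing behavior (the device choice is illustrative):

    import jax
    import jax.numpy as jnp

    x = jnp.arange(4)
    cpu0 = jax.devices("cpu")[0]

    # Cast and place in one call; per this patch this is equivalent to
    # jax.device_put(x.astype(jnp.float32), cpu0).
    y = jnp.astype(x, jnp.float32, device=cpu0)
    print(y.sharding)  # roughly: SingleDeviceSharding(device=CpuDevice(id=0))

    # copy=True forces a fresh buffer even when the dtype is unchanged.
    z = jnp.astype(x, x.dtype, copy=True)
    assert z is not x
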
""" from __future__ import annotations diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index b555576b3261..4f5c433aeb9b 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -195,13 +195,14 @@ def testTensorFlowToJaxInt64(self): shape=all_shapes, dtype=numpy_dtypes, copy=[False, True], + device_transfer=[False, True], ) - def testNumpyToJax(self, shape, dtype, copy): + def testNumpyToJax(self, shape, dtype, copy, device_transfer): rng = jtu.rand_default(self.rng()) x_np = rng(shape, dtype) - device = jax.devices()[0] + device = jax.devices()[0] if device_transfer else None _from_dlpack = lambda: jnp.from_dlpack(x_np, device=device, copy=copy) - if jax.default_backend() == 'gpu' and not copy: + if device_transfer and not copy: self.assertRaisesRegex( ValueError, r"Specified .* which requires a copy", diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index f46933c169c9..f38abadd2c78 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -41,7 +41,7 @@ import jax.ops from jax import lax from jax import numpy as jnp -from jax.sharding import SingleDeviceSharding +from jax.sharding import SingleDeviceSharding, PartitionSpec as P from jax.test_util import check_grads from jax._src import array @@ -3931,19 +3931,43 @@ def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'): self._CompileAndCheck(jnp_op, args_maker) @jtu.sample_product( - change_dtype=[True, False], + [dict(dtype=dtype, new_dtype=new_dtype) + for dtype in all_dtypes + for new_dtype in ( + complex_dtypes + if np.issubdtype(dtype, np.complexfloating) + else all_dtypes + ) + ], + shape=array_shapes, copy=[True, False], + device_type=[None, "single", "shard"], ) - def testAstypeCopy(self, change_dtype, copy): - dtype = 'float32' if change_dtype else 'int32' - expect_copy = change_dtype or copy - x = jnp.arange(5, dtype='int32') - y = x.astype(dtype, copy=copy) + @jtu.run_on_devices("gpu") + def testAstypePlacement(self, shape, dtype, new_dtype, copy, device_type): + rng = jtu.rand_default(self.rng()) + x = jnp.asarray(rng(shape, dtype)) + + if device_type is None: + device = None + expected_sharding = x.sharding + elif device_type == "single": + device = jax.devices("cpu")[0] + expected_sharding = SingleDeviceSharding(device) + else: + global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + device = jax.sharding.NamedSharding(global_mesh, P('x', 'y')) + expected_sharding = device + + expect_copy = (dtype != new_dtype) or copy or device - self.assertEqual(y.dtype, dtype) + y = x.astype(new_dtype, copy=copy, device=device) + self.assertEqual(y.dtype, new_dtype) + self.assertEqual(y.sharding, expected_sharding) y.delete() self.assertNotEqual(x.is_deleted(), expect_copy) + def testAstypeComplexDowncast(self): x = jnp.array(2.0+1.5j, dtype='complex64') msg = "Casting from complex to non-complex dtypes will soon raise "