Skip to content

Commit

Permalink
Add support for device kwarg in astype, and add matching utility func
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Apr 11, 2024
1 parent 2215021 commit be2dacc
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 11 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.27

* New Features
* {func}`jax.numpy.astype` supports new `device` keyword argument.

* Changes
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
Expand Down
34 changes: 28 additions & 6 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import opt_einsum

import jax
from jax import jit
from jax import jit, device_put
from jax import errors
from jax import lax
from jax.sharding import Sharding, SingleDeviceSharding
Expand Down Expand Up @@ -2293,12 +2293,34 @@ def astype(x: ArrayLike, dtype: DTypeLike | None,
warnings.simplefilter("ignore", ComplexWarning)
return lax.convert_element_type(out, dtype)

def _place_array(x, device=None, copy=None):
# TODO(micky774): Implement in future PRs as we formalize device placement
# semantics
def _get_device_set(x: ArrayLike | xc.Device | Sharding | None):
if x is None:
return None
elif isinstance(x, Sharding):
return x.device_set
elif isinstance(x, xc.Device):
return {x}
elif hasattr(x, "devices") and not isinstance(x, core.Tracer):
return x.devices()

def _place_array(x: jax.Array, device: xc.Device | Sharding | None = None, copy=None):
# TODO(micky774): Fine tune mechanics in future PRs as we formalize device
# placement semantics
devices = _get_device_set(device)
src_devices = _get_device_set(x)
if devices is not None and src_devices != devices:
if copy is not None and not copy:
raise ValueError(
f"Specified {device=} which requires a copy since the source devices "
f"are {src_devices}, however copy=False. Set copy=True or "
"copy=None to perform the requested operation."
)
out = device_put(x, device)
else:
out = x
if copy:
return _array_copy(x)
return x
return _array_copy(out)
return out


@util.implements(np.asarray, lax_description=_ARRAY_DOC)
Expand Down
20 changes: 15 additions & 5 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3843,16 +3843,26 @@ def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'):
@jtu.sample_product(
change_dtype=[True, False],
copy=[True, False],
change_device=[True, False],
)
def testAstypeCopy(self, change_dtype, copy):
def testAstypeCopy(self, change_dtype, copy, change_device):
if change_device and not jtu.test_device_matches(["gpu"]):
raise unittest.SkipTest(
"Testing device transfer requires at least two available devices."
)

dtype = 'float32' if change_dtype else 'int32'
expect_copy = change_dtype or copy
device = jax.devices("cpu")[0] if change_device else None
expect_copy = change_dtype or copy or change_device
x = jnp.arange(5, dtype='int32')
y = x.astype(dtype, copy=copy)
y = x.astype(dtype, copy=copy, device=device)

assert y.dtype == dtype
y.delete()
assert x.is_deleted() != expect_copy
if change_device:
assert y.devices() == {device}
else:
y.delete()
assert x.is_deleted() != expect_copy

def testAstypeComplexDowncast(self):
x = jnp.array(2.0+1.5j, dtype='complex64')
Expand Down

0 comments on commit be2dacc

Please sign in to comment.