Skip to content

Commit

Permalink
Finalize the deprecation of the arr.device() method
Browse files Browse the repository at this point in the history
The method has been emitting an DeprecationWarning since JAX v0.4.21, released December 2023. Existing uses can be replaced with `arr.devices()` or `arr.sharding`, depending on the context.

PiperOrigin-RevId: 622940937
  • Loading branch information
Jake VanderPlas authored and jax authors committed Apr 8, 2024
1 parent 7b486f4 commit 391a308
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 25 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list
* {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and
`a_max` are deprecated in favor of `x` (positonal only), `min`, and
`max` ({jax-issue}`20550`).
* The `device()` method of JAX arrays has been removed, after being deprecated
since JAX v0.4.21. Use `arr.devices()` instead.


## jaxlib 0.4.27
Expand Down
17 changes: 0 additions & 17 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from jax._src import basearray
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import dispatch
from jax._src import dtypes
from jax._src import errors
Expand All @@ -50,7 +49,6 @@
from jax._src.typing import ArrayLike
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method

deprecations.register(__name__, "device-method")

Shape = tuple[int, ...]
Device = xc.Device
Expand Down Expand Up @@ -471,21 +469,6 @@ def on_device_size_in_bytes(self):
per_shard_size = arr.on_device_size_in_bytes() # type: ignore
return per_shard_size * len(self.sharding.device_set)

# TODO(yashkatariya): Remove this method when everyone is using devices().
def device(self) -> Device:
if deprecations.is_accelerated(__name__, "device-method"):
raise NotImplementedError("arr.device() is deprecated. Use arr.devices() instead.")
else:
warnings.warn("arr.device() is deprecated. Use arr.devices() instead.",
DeprecationWarning, stacklevel=2)
self._check_if_deleted()
device_set = self.sharding.device_set
if len(device_set) == 1:
single_device, = device_set
return single_device
raise ValueError('Length of devices is greater than 1. '
'Please use `.devices()`.')

def devices(self) -> set[Device]:
self._check_if_deleted()
return self.sharding.device_set
Expand Down
8 changes: 4 additions & 4 deletions jax/experimental/array_api/_array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import jax
from jax._src.array import ArrayImpl
from jax.experimental.array_api._version import __array_api_version__
from jax.sharding import Sharding

from jax._src.lib import xla_extension as xe

Expand All @@ -30,16 +31,15 @@ def _array_namespace(self, /, *, api_version: None | str = None):
return jax.experimental.array_api


def _to_device(self, device: xe.Device | Callable[[], xe.Device], /, *,
def _to_device(self, device: xe.Device | Sharding | None, *,
stream: int | Any | None = None):
if stream is not None:
raise NotImplementedError("stream argument of array.to_device()")
# The type of device is defined by Array.device. In JAX, this is a callable that
# returns a device, so we must handle this case to satisfy the API spec.
return jax.device_put(self, device() if callable(device) else device)
return jax.device_put(self, device)


def add_array_object_methods():
# TODO(jakevdp): set on tracers as well?
setattr(ArrayImpl, "__array_namespace__", _array_namespace)
setattr(ArrayImpl, "to_device", _to_device)
setattr(ArrayImpl, "device", property(lambda self: self.sharding))
4 changes: 0 additions & 4 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from jax import random
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import dtypes
from jax._src import test_util as jtu
from jax import vmap
Expand Down Expand Up @@ -1019,9 +1018,6 @@ def test_array_impl_attributes(self):

self.assertEqual(key.is_fully_addressable, key._base_array.is_fully_addressable)
self.assertEqual(key.is_fully_replicated, key._base_array.is_fully_replicated)
if not deprecations.is_accelerated('jax._src.array', 'device-method'):
with jtu.ignore_warning(category=DeprecationWarning, message="arr.device"):
self.assertEqual(key.device(), key._base_array.device())
self.assertEqual(key.devices(), key._base_array.devices())
self.assertEqual(key.on_device_size_in_bytes(),
key._base_array.on_device_size_in_bytes())
Expand Down

0 comments on commit 391a308

Please sign in to comment.