From e0fa3cad25ad514c52ae98ae2bbcc97ff69615ed Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 25 Jun 2024 10:38:37 +0200 Subject: [PATCH] Added device kwargs to jnp.linspace, jnp.array, jnp.asarray --- jax/_src/numpy/lax_numpy.py | 62 +++++++++++++------ jax/experimental/array_api/__init__.py | 7 +-- .../array_api/_creation_functions.py | 25 -------- jax/numpy/__init__.pyi | 18 ++++-- tests/lax_numpy_test.py | 27 +++++--- 5 files changed, 77 insertions(+), 62 deletions(-) delete mode 100644 jax/experimental/array_api/_creation_functions.py diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ca6cec379878..6fa62654698c 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3205,9 +3205,13 @@ def _supports_buffer_protocol(obj): deprecations.register("jax-numpy-array-none") -@util.implements(np.array, lax_description=_ARRAY_DOC) +@util.implements(np.array, lax_description=_ARRAY_DOC, extra_params=""" +device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. +""") def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, - order: str | None = "K", ndmin: int = 0) -> Array: + order: str | None = "K", ndmin: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array: if order is not None and order != "K": raise NotImplementedError("Only implemented for order='K'") @@ -3223,8 +3227,8 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, # Use device_put to avoid a copy for ndarray inputs. if (not copy and isinstance(object, np.ndarray) and (dtype is None or dtype == object.dtype) and (ndmin <= object.ndim)): - # Keep the output uncommitted. - return jax.device_put(object) + # Keep the output uncommitted if device is None. + return jax.device_put(object, device=device) # For Python scalar literals, call coerce_to_array to catch any overflow # errors. We don't use dtypes.is_python_scalar because we don't want this @@ -3304,7 +3308,8 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, out = np.array(object) if copy else np.asarray(object) else: raise TypeError(f"Unexpected input type for array: {type(object)}") - + if device is not None: + out = jax.device_put(out, device=device) out_array: Array = lax_internal._convert_element_type( out, dtype, weak_type=weak_type) if ndmin > ndim(out_array): @@ -3326,6 +3331,9 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: have slightly different behavior than :func:`numpy.astype` in some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent. +""", extra_params=""" +device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. """) def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = False, @@ -3365,9 +3373,13 @@ def _place_array(x, device=None, copy=None): return x -@util.implements(np.asarray, lax_description=_ARRAY_DOC) +@util.implements(np.asarray, lax_description=_ARRAY_DOC, extra_params=""" +device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. +""") def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, - *, copy: bool | None = None) -> Array: + *, copy: bool | None = None, + device: xc.Device | Sharding | None = None) -> Array: # For copy=False, the array API specifies that we raise a ValueError if the input supports # the buffer protocol but a copy is required. Since array() supports the buffer protocol # via numpy, this is only the case when the default device is not 'cpu' @@ -3380,7 +3392,7 @@ def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, dtypes.check_user_dtype_supported(dtype, "asarray") if dtype is not None: dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] - return array(a, dtype=dtype, copy=bool(copy), order=order) + return array(a, dtype=dtype, copy=bool(copy), order=order, device=device) @util.implements(np.copy, lax_description=_ARRAY_DOC) @@ -3956,10 +3968,8 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array: ``(jnp.arange(-600, 600) * .01).astype(jnp.bfloat16)`` to generate a sequence in a higher precision and then convert it to the desired lower precision. """, extra_params=""" -device : :py:class:`Device`, :py:class:`Sharding`, optional - The (optional) :py:class:`Device`, :py:class:`Sharding`, - representing the device(s) to which created array should be - transferred. If given, then the result is committed to the device(s). +device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. """ ) def arange(start: DimSize, stop: DimSize | None = None, @@ -4041,30 +4051,44 @@ def _arange_dynamic( def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, retstep: Literal[False] = False, dtype: DTypeLike | None = None, - axis: int = 0) -> Array: ... + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int, endpoint: bool, retstep: Literal[True], dtype: DTypeLike | None = None, - axis: int = 0) -> tuple[Array, Array]: ... + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, *, retstep: Literal[True], dtype: DTypeLike | None = None, - axis: int = 0) -> tuple[Array, Array]: ... + axis: int = 0, + device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, retstep: bool = False, dtype: DTypeLike | None = None, - axis: int = 0) -> Array | tuple[Array, Array]: ... -@util.implements(np.linspace) + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ... +@util.implements(np.linspace, extra_params=""" +device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. +""") def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, retstep: bool = False, dtype: DTypeLike | None = None, - axis: int = 0) -> Array | tuple[Array, Array]: + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.linspace") axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace") - return _linspace(start, stop, num, endpoint, retstep, dtype, axis) + + # TODO(vfdev-5): optimize putting the array directly on the device specified + # instead of putting it on default device and then on the specific device + output = _linspace(start, stop, num, endpoint, retstep, dtype, axis) + if device is not None: + return jax.device_put(output, device=device) + return output @partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis')) def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index e0d8c4ee67f5..a339355fac8d 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -52,6 +52,7 @@ argmax as argmax, argmin as argmin, argsort as argsort, + asarray as asarray, asin as asin, asinh as asinh, atan as atan, @@ -108,6 +109,7 @@ isnan as isnan, less as less, less_equal as less_equal, + linspace as linspace, log as log, log10 as log10, log1p as log1p, @@ -184,11 +186,6 @@ reshape as reshape, ) -from jax.experimental.array_api._creation_functions import ( - asarray as asarray, - linspace as linspace, -) - from jax.experimental.array_api._data_type_functions import ( astype as astype, finfo as finfo, diff --git a/jax/experimental/array_api/_creation_functions.py b/jax/experimental/array_api/_creation_functions.py deleted file mode 100644 index 5b9789ed732d..000000000000 --- a/jax/experimental/array_api/_creation_functions.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import jax -import jax.numpy as jnp - - -def asarray(obj, /, *, dtype=None, device=None, copy=None): - return jax.device_put(jnp.array(obj, dtype=dtype, copy=copy), device=device) - -def linspace(start, stop, /, num, *, dtype=None, device=None, endpoint=True): - return jax.device_put(jnp.linspace(start, stop, num=num, dtype=dtype, endpoint=endpoint), device=device) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 6e9c7af5eabb..0b2a8301a186 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -99,7 +99,8 @@ def argwhere( ) -> Array: ... around = round def array(object: Any, dtype: DTypeLike | None = ..., copy: builtins.bool = True, - order: str | None = ..., ndmin: int = ...) -> Array: ... + order: str | None = ..., ndmin: int = ..., + *, device: _Device | _Sharding | None = ...) -> Array: ... def array_equal( a1: ArrayLike, a2: ArrayLike, equal_nan: builtins.bool = ... ) -> Array: ... @@ -113,7 +114,8 @@ def array_split( array_str = _np.array_str def asarray( a: Any, dtype: DTypeLike | None = ..., order: str | None = ..., - *, copy: builtins.bool | None = ... + *, copy: builtins.bool | None = ..., + device: _Device | _Sharding | None = ..., ) -> Array: ... def asin(x: ArrayLike, /) -> Array: ... def asinh(x: ArrayLike, /) -> Array: ... @@ -522,22 +524,26 @@ def lexsort(keys: Sequence[ArrayLike], axis: int = ...) -> Array: ... def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: builtins.bool = True, retstep: Literal[False] = False, dtype: DTypeLike | None = ..., - axis: int = 0) -> Array: ... + axis: int = 0, + *, device: _Device | _Sharding | None = ...) -> Array: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int, endpoint: builtins.bool, retstep: Literal[True], dtype: DTypeLike | None = ..., - axis: int = 0) -> tuple[Array, Array]: ... + axis: int = 0, + *, device: _Device | _Sharding | None = ...) -> tuple[Array, Array]: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: builtins.bool = True, *, retstep: Literal[True], dtype: DTypeLike | None = ..., - axis: int = 0) -> tuple[Array, Array]: ... + axis: int = 0, + device: _Device | _Sharding | None = ...) -> tuple[Array, Array]: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: builtins.bool = True, retstep: builtins.bool = False, dtype: DTypeLike | None = ..., - axis: int = 0) -> Array | tuple[Array, Array]: ... + axis: int = 0, + *, device: _Device | _Sharding | None = ...) -> Union[Array, tuple[Array, Array]]: ... def load(*args: Any, **kwargs: Any) -> Array: ... def log(x: ArrayLike, /) -> Array: ... diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 45cc177fbfd1..1b01adc6ccaf 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2992,25 +2992,39 @@ def testArrayCreationWithSharding(self, func, shape, dtype): func=[ lambda dtype, device: jnp.arange(5, dtype=dtype, device=device), lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device), + lambda dtype, device: jnp.linspace(5, 6, 7, dtype=dtype, device=device), + lambda dtype, device: jnp.linspace(5, 6, 7, retstep=True, dtype=dtype, device=device), + lambda dtype, device: jnp.array([1, 2, 3, 4, 5], dtype=dtype, device=device), ], dtype=default_dtypes, ) - def testArangeEyeWithDevice(self, func, dtype): + def testArangeEyeLinspaceArrayWithDevice(self, func, dtype): device = jax.devices()[-1] - out = func(dtype=dtype, device=device) - self.assertEqual(out.devices(), {device}) + output = func(dtype=dtype, device=device) + if isinstance(output, tuple): + for out in output: + self.assertEqual(out.devices(), {device}) + else: + self.assertEqual(output.devices(), {device}) @jtu.sample_product( func=[ lambda dtype, device: jnp.arange(5, dtype=dtype, device=device), lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device), + lambda dtype, device: jnp.linspace(5, 6, 7, dtype=dtype, device=device), + lambda dtype, device: jnp.linspace(5, 6, 7, retstep=True, dtype=dtype, device=device), + lambda dtype, device: jnp.array([1, 2, 3, 4, 5], dtype=dtype, device=device), ], dtype=default_dtypes, ) - def testArangeEyeWithSharding(self, func, dtype): + def testArangeEyeLinspaceArrayWithSharding(self, func, dtype): sharding = SingleDeviceSharding(jax.devices()[-1]) - out = func(dtype=dtype, device=sharding) - self.assertEqual(out.sharding, sharding) + output = func(dtype=dtype, device=sharding) + if isinstance(output, tuple): + for out in output: + self.assertEqual(out.sharding, sharding) + else: + self.assertEqual(output.sharding, sharding) @jtu.sample_product( func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like], @@ -5984,7 +5998,6 @@ def testWrappedSignaturesMatch(self): 'histogram': ['normed'], 'histogram2d': ['normed'], 'histogramdd': ['normed'], - 'linspace': ['device'], 'nanpercentile': ['weights'], 'nanquantile': ['weights'], 'nanstd': ['correction', 'mean'],