Skip to content

Commit

Permalink
Added device kwargs to jnp.linspace, jnp.array, jnp.asarray
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Jun 28, 2024
1 parent 24b42ee commit e0fa3ca
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 62 deletions.
62 changes: 43 additions & 19 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3205,9 +3205,13 @@ def _supports_buffer_protocol(obj):

deprecations.register("jax-numpy-array-none")

@util.implements(np.array, lax_description=_ARRAY_DOC)
@util.implements(np.array, lax_description=_ARRAY_DOC, extra_params="""
device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
""")
def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
order: str | None = "K", ndmin: int = 0) -> Array:
order: str | None = "K", ndmin: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array:
if order is not None and order != "K":
raise NotImplementedError("Only implemented for order='K'")

Expand All @@ -3223,8 +3227,8 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
# Use device_put to avoid a copy for ndarray inputs.
if (not copy and isinstance(object, np.ndarray) and
(dtype is None or dtype == object.dtype) and (ndmin <= object.ndim)):
# Keep the output uncommitted.
return jax.device_put(object)
# Keep the output uncommitted if device is None.
return jax.device_put(object, device=device)

# For Python scalar literals, call coerce_to_array to catch any overflow
# errors. We don't use dtypes.is_python_scalar because we don't want this
Expand Down Expand Up @@ -3304,7 +3308,8 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
out = np.array(object) if copy else np.asarray(object)
else:
raise TypeError(f"Unexpected input type for array: {type(object)}")

if device is not None:
out = jax.device_put(out, device=device)
out_array: Array = lax_internal._convert_element_type(
out, dtype, weak_type=weak_type)
if ndmin > ndim(out_array):
Expand All @@ -3326,6 +3331,9 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
have slightly different behavior than :func:`numpy.astype` in some cases.
In particular, the details of float-to-int and int-to-float casts are
implementation dependent.
""", extra_params="""
device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
""")
def astype(x: ArrayLike, dtype: DTypeLike | None,
/, *, copy: bool = False,
Expand Down Expand Up @@ -3365,9 +3373,13 @@ def _place_array(x, device=None, copy=None):
return x


@util.implements(np.asarray, lax_description=_ARRAY_DOC)
@util.implements(np.asarray, lax_description=_ARRAY_DOC, extra_params="""
device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
""")
def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
*, copy: bool | None = None) -> Array:
*, copy: bool | None = None,
device: xc.Device | Sharding | None = None) -> Array:
# For copy=False, the array API specifies that we raise a ValueError if the input supports
# the buffer protocol but a copy is required. Since array() supports the buffer protocol
# via numpy, this is only the case when the default device is not 'cpu'
Expand All @@ -3380,7 +3392,7 @@ def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
dtypes.check_user_dtype_supported(dtype, "asarray")
if dtype is not None:
dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment]
return array(a, dtype=dtype, copy=bool(copy), order=order)
return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)


@util.implements(np.copy, lax_description=_ARRAY_DOC)
Expand Down Expand Up @@ -3956,10 +3968,8 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array:
``(jnp.arange(-600, 600) * .01).astype(jnp.bfloat16)`` to generate a sequence in a higher precision
and then convert it to the desired lower precision.
""", extra_params="""
device : :py:class:`Device`, :py:class:`Sharding`, optional
The (optional) :py:class:`Device`, :py:class:`Sharding`,
representing the device(s) to which created array should be
transferred. If given, then the result is committed to the device(s).
device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
"""
)
def arange(start: DimSize, stop: DimSize | None = None,
Expand Down Expand Up @@ -4041,30 +4051,44 @@ def _arange_dynamic(
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: Literal[False] = False,
dtype: DTypeLike | None = None,
axis: int = 0) -> Array: ...
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int,
endpoint: bool, retstep: Literal[True],
dtype: DTypeLike | None = None,
axis: int = 0) -> tuple[Array, Array]: ...
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, *, retstep: Literal[True],
dtype: DTypeLike | None = None,
axis: int = 0) -> tuple[Array, Array]: ...
axis: int = 0,
device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: bool = False,
dtype: DTypeLike | None = None,
axis: int = 0) -> Array | tuple[Array, Array]: ...
@util.implements(np.linspace)
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ...
@util.implements(np.linspace, extra_params="""
device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
""")
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: bool = False,
dtype: DTypeLike | None = None,
axis: int = 0) -> Array | tuple[Array, Array]:
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]:
num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.linspace")
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
return _linspace(start, stop, num, endpoint, retstep, dtype, axis)

# TODO(vfdev-5): optimize putting the array directly on the device specified
# instead of putting it on default device and then on the specific device
output = _linspace(start, stop, num, endpoint, retstep, dtype, axis)
if device is not None:
return jax.device_put(output, device=device)
return output

@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis'))
def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
Expand Down
7 changes: 2 additions & 5 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
argmax as argmax,
argmin as argmin,
argsort as argsort,
asarray as asarray,
asin as asin,
asinh as asinh,
atan as atan,
Expand Down Expand Up @@ -108,6 +109,7 @@
isnan as isnan,
less as less,
less_equal as less_equal,
linspace as linspace,
log as log,
log10 as log10,
log1p as log1p,
Expand Down Expand Up @@ -184,11 +186,6 @@
reshape as reshape,
)

from jax.experimental.array_api._creation_functions import (
asarray as asarray,
linspace as linspace,
)

from jax.experimental.array_api._data_type_functions import (
astype as astype,
finfo as finfo,
Expand Down
25 changes: 0 additions & 25 deletions jax/experimental/array_api/_creation_functions.py

This file was deleted.

18 changes: 12 additions & 6 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def argwhere(
) -> Array: ...
around = round
def array(object: Any, dtype: DTypeLike | None = ..., copy: builtins.bool = True,
order: str | None = ..., ndmin: int = ...) -> Array: ...
order: str | None = ..., ndmin: int = ...,
*, device: _Device | _Sharding | None = ...) -> Array: ...
def array_equal(
a1: ArrayLike, a2: ArrayLike, equal_nan: builtins.bool = ...
) -> Array: ...
Expand All @@ -113,7 +114,8 @@ def array_split(
array_str = _np.array_str
def asarray(
a: Any, dtype: DTypeLike | None = ..., order: str | None = ...,
*, copy: builtins.bool | None = ...
*, copy: builtins.bool | None = ...,
device: _Device | _Sharding | None = ...,
) -> Array: ...
def asin(x: ArrayLike, /) -> Array: ...
def asinh(x: ArrayLike, /) -> Array: ...
Expand Down Expand Up @@ -522,22 +524,26 @@ def lexsort(keys: Sequence[ArrayLike], axis: int = ...) -> Array: ...
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: builtins.bool = True, retstep: Literal[False] = False,
dtype: DTypeLike | None = ...,
axis: int = 0) -> Array: ...
axis: int = 0,
*, device: _Device | _Sharding | None = ...) -> Array: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int,
endpoint: builtins.bool, retstep: Literal[True],
dtype: DTypeLike | None = ...,
axis: int = 0) -> tuple[Array, Array]: ...
axis: int = 0,
*, device: _Device | _Sharding | None = ...) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: builtins.bool = True, *, retstep: Literal[True],
dtype: DTypeLike | None = ...,
axis: int = 0) -> tuple[Array, Array]: ...
axis: int = 0,
device: _Device | _Sharding | None = ...) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: builtins.bool = True, retstep: builtins.bool = False,
dtype: DTypeLike | None = ...,
axis: int = 0) -> Array | tuple[Array, Array]: ...
axis: int = 0,
*, device: _Device | _Sharding | None = ...) -> Union[Array, tuple[Array, Array]]: ...

def load(*args: Any, **kwargs: Any) -> Array: ...
def log(x: ArrayLike, /) -> Array: ...
Expand Down
27 changes: 20 additions & 7 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2992,25 +2992,39 @@ def testArrayCreationWithSharding(self, func, shape, dtype):
func=[
lambda dtype, device: jnp.arange(5, dtype=dtype, device=device),
lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, retstep=True, dtype=dtype, device=device),
lambda dtype, device: jnp.array([1, 2, 3, 4, 5], dtype=dtype, device=device),
],
dtype=default_dtypes,
)
def testArangeEyeWithDevice(self, func, dtype):
def testArangeEyeLinspaceArrayWithDevice(self, func, dtype):
device = jax.devices()[-1]
out = func(dtype=dtype, device=device)
self.assertEqual(out.devices(), {device})
output = func(dtype=dtype, device=device)
if isinstance(output, tuple):
for out in output:
self.assertEqual(out.devices(), {device})
else:
self.assertEqual(output.devices(), {device})

@jtu.sample_product(
func=[
lambda dtype, device: jnp.arange(5, dtype=dtype, device=device),
lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, retstep=True, dtype=dtype, device=device),
lambda dtype, device: jnp.array([1, 2, 3, 4, 5], dtype=dtype, device=device),
],
dtype=default_dtypes,
)
def testArangeEyeWithSharding(self, func, dtype):
def testArangeEyeLinspaceArrayWithSharding(self, func, dtype):
sharding = SingleDeviceSharding(jax.devices()[-1])
out = func(dtype=dtype, device=sharding)
self.assertEqual(out.sharding, sharding)
output = func(dtype=dtype, device=sharding)
if isinstance(output, tuple):
for out in output:
self.assertEqual(out.sharding, sharding)
else:
self.assertEqual(output.sharding, sharding)

@jtu.sample_product(
func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like],
Expand Down Expand Up @@ -5984,7 +5998,6 @@ def testWrappedSignaturesMatch(self):
'histogram': ['normed'],
'histogram2d': ['normed'],
'histogramdd': ['normed'],
'linspace': ['device'],
'nanpercentile': ['weights'],
'nanquantile': ['weights'],
'nanstd': ['correction', 'mean'],
Expand Down

0 comments on commit e0fa3ca

Please sign in to comment.