Skip to content

Commit

Permalink
Added device to jnp.arange, jnp.eye and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Jun 26, 2024
1 parent 66287cd commit 78ee8a5
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 18 deletions.
40 changes: 37 additions & 3 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3603,8 +3603,25 @@ def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: s
return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep))


@util.implements(np.eye)
@util.implements(np.eye, extra_params="""
device : :py:class:`Device`, :py:class:`Sharding`, optional
The (optional) :py:class:`Device`, :py:class:`Sharding`,
representing the device(s) to which created array should be
transferred. If given, then the result is committed to the device(s).
""")
def eye(N: DimSize, M: DimSize | None = None,
k: int | ArrayLike = 0,
dtype: DTypeLike | None = None,
*, device: xc.Device | Sharding | None = None) -> Array:
# TODO(vfdev-5): optimize putting the array directly on the device specified
# instead of putting it on default device and then on the specific device
output = _eye(N, M=M, k=k, dtype=dtype)
if device is not None:
return jax.device_put(output, device=device)
return output


def _eye(N: DimSize, M: DimSize | None = None,
k: int | ArrayLike = 0,
dtype: DTypeLike | None = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "eye")
Expand All @@ -3629,7 +3646,7 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array:
return eye(n, dtype=dtype)


@util.implements(np.arange,lax_description= """
@util.implements(np.arange, lax_description= """
.. note::
Using ``arange`` with the ``step`` argument can lead to precision errors,
Expand All @@ -3638,8 +3655,25 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array:
To avoid precision errors, consider using an expression like
``(jnp.arange(-600, 600) * .01).astype(jnp.bfloat16)`` to generate a sequence in a higher precision
and then convert it to the desired lower precision.
""")
""", extra_params="""
device : :py:class:`Device`, :py:class:`Sharding`, optional
The (optional) :py:class:`Device`, :py:class:`Sharding`,
representing the device(s) to which created array should be
transferred. If given, then the result is committed to the device(s).
"""
)
def arange(start: DimSize, stop: DimSize | None = None,
step: DimSize | None = None, dtype: DTypeLike | None = None,
*, device: xc.Device | Sharding | None = None) -> Array:
# TODO(vfdev-5): optimize putting the array directly on the device specified
# instead of putting it on default device and then on the specific device
output = _arange(start, stop=stop, step=step, dtype=dtype)
if device is not None:
return jax.device_put(output, device=device)
return output


def _arange(start: DimSize, stop: DimSize | None = None,
step: DimSize | None = None, dtype: DTypeLike | None = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "arange")
if not config.dynamic_shapes.value:
Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
add as add,
all as all,
any as any,
arange as arange,
argmax as argmax,
argmin as argmin,
argsort as argsort,
Expand Down Expand Up @@ -83,6 +84,7 @@
exp as exp,
expand_dims as expand_dims,
expm1 as expm1,
eye as eye,
flip as flip,
float32 as float32,
float64 as float64,
Expand Down Expand Up @@ -183,9 +185,7 @@
)

from jax.experimental.array_api._creation_functions import (
arange as arange,
asarray as asarray,
eye as eye,
linspace as linspace,
)

Expand Down
6 changes: 0 additions & 6 deletions jax/experimental/array_api/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,9 @@
import jax
import jax.numpy as jnp

# TODO(micky774): Deprecate after adding device argument to jax.numpy functions
def arange(start, /, stop=None, step=1, *, dtype=None, device=None):
return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device)

def asarray(obj, /, *, dtype=None, device=None, copy=None):
return jax.device_put(jnp.array(obj, dtype=dtype, copy=copy), device=device)

def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None):
return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device)

def linspace(start, stop, /, num, *, dtype=None, device=None, endpoint=True):
return jax.device_put(jnp.linspace(start, stop, num=num, dtype=dtype, endpoint=endpoint), device=device)
6 changes: 4 additions & 2 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def arange(
start: DimSize,
stop: DimSize | None = ...,
step: DimSize | None = ...,
dtype: DTypeLike | None = ...,
dtype: DTypeLike | None = ..., *,
device: _Device | _Sharding | None = ...,
) -> Array: ...
def arccos(x: ArrayLike, /) -> Array: ...
def arccosh(x: ArrayLike, /) -> Array: ...
Expand Down Expand Up @@ -353,7 +354,8 @@ def expm1(x: ArrayLike, /) -> Array: ...
def extract(condition: ArrayLike, arr: ArrayLike, *,
size: int | None = None, fill_value: ArrayLike = 0) -> Array: ...
def eye(N: DimSize, M: DimSize | None = ..., k: int | ArrayLike = ...,
dtype: DTypeLike | None = ...) -> Array: ...
dtype: DTypeLike | None = ..., *,
device: _Device | _Sharding | None = ...) -> Array: ...
def fabs(x: ArrayLike, /) -> Array: ...
finfo = _dtypes.finfo
def fix(x: ArrayLike, out: None = ...) -> Array: ...
Expand Down
43 changes: 38 additions & 5 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2988,6 +2988,30 @@ def testArrayCreationWithSharding(self, func, shape, dtype):
out = func(**kwds, shape=shape, dtype=dtype, device=sharding)
self.assertEqual(out.sharding, sharding)

@jtu.sample_product(
func=[
lambda dtype, device: jnp.arange(5, dtype=dtype, device=device),
lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device),
],
dtype=default_dtypes,
)
def testArangeEyeWithDevice(self, func, dtype):
device = jax.devices()[-1]
out = func(dtype=dtype, device=device)
self.assertEqual(out.devices(), {device})

@jtu.sample_product(
func=[
lambda dtype, device: jnp.arange(5, dtype=dtype, device=device),
lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device),
],
dtype=default_dtypes,
)
def testArangeEyeWithSharding(self, func, dtype):
sharding = SingleDeviceSharding(jax.devices()[-1])
out = func(dtype=dtype, device=sharding)
self.assertEqual(out.sharding, sharding)

@jtu.sample_product(
func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like],
shape=array_shapes,
Expand Down Expand Up @@ -4796,10 +4820,19 @@ def testArangeJit(self):
expected = jtu.with_jax_dtype_defaults(np.arange)(5)
self.assertAllClose(ans, expected)

@jtu.sample_product(args=[(5,), (0, 5)])
def testArangeJaxpr(self, args):
jaxpr = jax.make_jaxpr(lambda: jnp.arange(*args))()
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
@jtu.sample_product(
args=[(5,), (0, 5)],
specify_device=[True, False],
)
def testArangeJaxpr(self, args, specify_device):
device = jax.devices()[-1] if specify_device else None
kwargs = {"device": device}
jaxpr = jax.make_jaxpr(lambda: jnp.arange(*args, **kwargs))()
# We have 2 statements in jaxpr:
# [a:i32[5] = iota[dimension=0 dtype=int32 shape=(5,)],
# a:i32[5] = device_put[devices=[None] srcs=[None]] b]
num_eqs = 2 if device is not None else 1
self.assertEqual(len(jaxpr.jaxpr.eqns), num_eqs)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p)

def testIssue830(self):
Expand Down Expand Up @@ -5941,7 +5974,7 @@ def testWrappedSignaturesMatch(self):
'empty_like': ['subok', 'order'],
'einsum': ['kwargs'],
'einsum_path': ['einsum_call'],
'eye': ['device', 'order', 'like'],
'eye': ['order', 'like'],
'hstack': ['casting'],
'identity': ['like'],
'isin': ['kind'],
Expand Down

0 comments on commit 78ee8a5

Please sign in to comment.