Skip to content

Commit

Permalink
Finalize jax.numpy array API compliance
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed May 21, 2024
1 parent d33a568 commit 10c5c2d
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 142 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.29

* New Functionality
* The {mod}`jax.numpy` module is now compliant with the Python array API
2023 standard.

* Deprecations
* Removed a number of previously-deprecated APIs:
* from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape`
Expand Down
7 changes: 7 additions & 0 deletions jax/_src/basearray.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,13 @@ class Array(abc.ABC):
def traceback(self) -> Traceback: ...
def unsafe_buffer_pointer(self) -> int: ...

# Array API inclusions
@property
def device(self) -> Sharding: ...
def to_device(self, device: Device | Sharding | None, *,
stream: int | Any | None = None) -> Sharding: ...
def __array_namespace__(self, /, *,
api_version: None | str = None): ...

ArrayLike = Union[
Array, # JAX array type
Expand Down
15 changes: 15 additions & 0 deletions jax/_src/numpy/_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__array_api_version__ = '2023.12'
19 changes: 19 additions & 0 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from jax._src.ops import scatter
from jax._src.typing import Array, ArrayLike, DimSize, DTypeLike, Shape
from jax._src.util import safe_zip, safe_map
from jax._src.numpy._version import __array_api_version__

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Expand All @@ -56,6 +57,21 @@
# operator overloads mainly just forward calls to the corresponding lax_numpy
# functions, which can themselves handle instances from any of these classes.

def _array_namespace(arr: ArrayLike, /, *, api_version: None | str = None):
if api_version is not None and api_version != __array_api_version__:
raise ValueError(f"{api_version=!r} is not available; "
f"available versions are: {[__array_api_version__]}")
return jax.numpy

def _to_device(arr: ArrayLike, device: xc.Device | Sharding | None, *,
stream: int | Any | None = None):
if stream is not None:
raise NotImplementedError("stream argument of array.to_device()")
return jax.device_put(arr, device)

def _device(arr: Array) -> Sharding:
"""Length of one array element in bytes."""
return arr.sharding

def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array:
"""Copy the array and cast to a specified dtype.
Expand Down Expand Up @@ -658,6 +674,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False,
}

_array_methods = {
"__array_namespace__": _array_namespace,
"all": reductions.all,
"any": reductions.any,
"argmax": lax_numpy.argmax,
Expand Down Expand Up @@ -694,6 +711,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False,
"sum": reductions.sum,
"swapaxes": lax_numpy.swapaxes,
"take": lax_numpy.take,
"to_device": _to_device,
"trace": lax_numpy.trace,
"transpose": _transpose,
"var": reductions.var,
Expand All @@ -718,6 +736,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False,
"nbytes": _nbytes,
"itemsize": _itemsize,
"at": _IndexUpdateHelper,
"device": _device,
}

def _set_shaped_array_attributes(shaped_array):
Expand Down
67 changes: 66 additions & 1 deletion jax/_src/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from functools import partial
import re
import textwrap
from typing import Any, Callable, NamedTuple, TypeVar
from typing import Any, Callable, NamedTuple, TypeVar, Tuple

import warnings

Expand All @@ -28,6 +28,8 @@
from jax._src.lax import lax
from jax._src.util import safe_zip, safe_map
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
from jax._src.sharding import Sharding
from jax._src.lib import xla_client as xc

import numpy as np

Expand Down Expand Up @@ -453,3 +455,66 @@ def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array:
except:
is_always_empty = False # can fail with dynamic shapes
return lax.select(condition_arr, x_arr, y_arr) if not is_always_empty else x_arr

class __array_namespace_info__:

def __init__(self):
self._capabilities = {
"boolean indexing": True,
"data-dependent shapes": False,
}


def _build_dtype_dict(self):
array_api_types = {
"bool", "int8", "int16",
"int32", "uint8", "uint16",
"uint32", "float32", "complex64"
}
if config.enable_x64.value:
array_api_types |= {"int64", "uint64", "float64", "complex128"}
return {category: {t.name: t for t in types if t.name in array_api_types}
for category, types in dtypes._dtype_kinds.items()}

def default_device(self):
# By default JAX arrays are uncommitted (device=None), meaning that
# JAX is free to choose the most efficient device placement.
return None

def devices(self):
return api.devices()

def capabilities(self):
return self._capabilities

def default_dtypes(self, *, device: xc.Device | Sharding | None = None):
# Array API supported dtypes are device-independent in JAX
del device
default_dtypes = {
"real floating": "f",
"complex floating": "c",
"integral": "i",
"indexing": "i",
}
return {
dtype_name: dtypes.canonicalize_dtype(
dtypes._default_types.get(kind)
) for dtype_name, kind in default_dtypes.items()
}

def dtypes(
self, *,
device: xc.Device | Sharding | None = None,
kind: str | Tuple[str, ...] | None = None):
# Array API supported dtypes are device-independent in JAX
del device
data_types = self._build_dtype_dict()
if kind is None:
out_dict = data_types["numeric"] | data_types["bool"]
elif isinstance(kind, tuple):
out_dict = {}
for _kind in kind:
out_dict |= data_types[_kind]
else:
out_dict = data_types[kind]
return out_dict
12 changes: 2 additions & 10 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@

from __future__ import annotations

from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__

from jax.experimental.array_api import fft as fft
from jax.experimental.array_api import linalg as linalg

from jax.numpy import (
__array_api_version__ as __array_api_version__,
__array_namespace_info__ as __array_namespace_info__,
abs as abs,
acos as acos,
acosh as acosh,
Expand Down Expand Up @@ -203,11 +203,3 @@
std as std,
var as var,
)

from jax.experimental.array_api._utility_functions import (
__array_namespace_info__ as __array_namespace_info__,
)

from jax.experimental.array_api import _array_methods
_array_methods.add_array_object_methods()
del _array_methods
45 changes: 0 additions & 45 deletions jax/experimental/array_api/_array_methods.py

This file was deleted.

86 changes: 0 additions & 86 deletions jax/experimental/array_api/_utility_functions.py

This file was deleted.

5 changes: 5 additions & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,11 @@

from jax._src.numpy.vectorize import vectorize as vectorize

from jax._src.numpy._version import __array_api_version__
from jax._src.numpy.util import (
__array_namespace_info__ as __array_namespace_info__,
)

# Dynamically register numpy-style methods on JAX arrays.
from jax._src.numpy.array_methods import register_jax_array_methods
register_jax_array_methods()
Expand Down

0 comments on commit 10c5c2d

Please sign in to comment.