From 10c5c2db8c88a4349c9af8b3cd65dc2b0ec80649 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 21 May 2024 00:38:57 +0000 Subject: [PATCH] Finalize jax.numpy array API compliance --- CHANGELOG.md | 4 + jax/_src/basearray.pyi | 7 ++ jax/_src/numpy/_version.py | 15 ++++ jax/_src/numpy/array_methods.py | 19 ++++ jax/_src/numpy/util.py | 67 ++++++++++++++- jax/experimental/array_api/__init__.py | 12 +-- jax/experimental/array_api/_array_methods.py | 45 ---------- .../array_api/_utility_functions.py | 86 ------------------- jax/numpy/__init__.py | 5 ++ 9 files changed, 118 insertions(+), 142 deletions(-) create mode 100644 jax/_src/numpy/_version.py delete mode 100644 jax/experimental/array_api/_array_methods.py delete mode 100644 jax/experimental/array_api/_utility_functions.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c7f1e0eabf91..3f1a16e26d11 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.29 +* New Functionality + * The {mod}`jax.numpy` module is now compliant with the Python array API + 2023 standard. + * Deprecations * Removed a number of previously-deprecated APIs: * from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape` diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index 5eb4e9e5c0b8..63348e31de6d 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -213,6 +213,13 @@ class Array(abc.ABC): def traceback(self) -> Traceback: ... def unsafe_buffer_pointer(self) -> int: ... + # Array API inclusions + @property + def device(self) -> Sharding: ... + def to_device(self, device: Device | Sharding | None, *, + stream: int | Any | None = None) -> Sharding: ... + def __array_namespace__(self, /, *, + api_version: None | str = None): ... ArrayLike = Union[ Array, # JAX array type diff --git a/jax/_src/numpy/_version.py b/jax/_src/numpy/_version.py new file mode 100644 index 000000000000..776dbf5fcb64 --- /dev/null +++ b/jax/_src/numpy/_version.py @@ -0,0 +1,15 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__array_api_version__ = '2023.12' diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 1d27c4b3aa28..3e2f977c6644 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -45,6 +45,7 @@ from jax._src.ops import scatter from jax._src.typing import Array, ArrayLike, DimSize, DTypeLike, Shape from jax._src.util import safe_zip, safe_map +from jax._src.numpy._version import __array_api_version__ map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -56,6 +57,21 @@ # operator overloads mainly just forward calls to the corresponding lax_numpy # functions, which can themselves handle instances from any of these classes. +def _array_namespace(arr: ArrayLike, /, *, api_version: None | str = None): + if api_version is not None and api_version != __array_api_version__: + raise ValueError(f"{api_version=!r} is not available; " + f"available versions are: {[__array_api_version__]}") + return jax.numpy + +def _to_device(arr: ArrayLike, device: xc.Device | Sharding | None, *, + stream: int | Any | None = None): + if stream is not None: + raise NotImplementedError("stream argument of array.to_device()") + return jax.device_put(arr, device) + +def _device(arr: Array) -> Sharding: + """Length of one array element in bytes.""" + return arr.sharding def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array: """Copy the array and cast to a specified dtype. @@ -658,6 +674,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False, } _array_methods = { + "__array_namespace__": _array_namespace, "all": reductions.all, "any": reductions.any, "argmax": lax_numpy.argmax, @@ -694,6 +711,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False, "sum": reductions.sum, "swapaxes": lax_numpy.swapaxes, "take": lax_numpy.take, + "to_device": _to_device, "trace": lax_numpy.trace, "transpose": _transpose, "var": reductions.var, @@ -718,6 +736,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False, "nbytes": _nbytes, "itemsize": _itemsize, "at": _IndexUpdateHelper, + "device": _device, } def _set_shaped_array_attributes(shaped_array): diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index fb3b7e4e9dc9..57b62afde283 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -17,7 +17,7 @@ from functools import partial import re import textwrap -from typing import Any, Callable, NamedTuple, TypeVar +from typing import Any, Callable, NamedTuple, TypeVar, Tuple import warnings @@ -28,6 +28,8 @@ from jax._src.lax import lax from jax._src.util import safe_zip, safe_map from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape +from jax._src.sharding import Sharding +from jax._src.lib import xla_client as xc import numpy as np @@ -453,3 +455,66 @@ def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array: except: is_always_empty = False # can fail with dynamic shapes return lax.select(condition_arr, x_arr, y_arr) if not is_always_empty else x_arr + +class __array_namespace_info__: + + def __init__(self): + self._capabilities = { + "boolean indexing": True, + "data-dependent shapes": False, + } + + + def _build_dtype_dict(self): + array_api_types = { + "bool", "int8", "int16", + "int32", "uint8", "uint16", + "uint32", "float32", "complex64" + } + if config.enable_x64.value: + array_api_types |= {"int64", "uint64", "float64", "complex128"} + return {category: {t.name: t for t in types if t.name in array_api_types} + for category, types in dtypes._dtype_kinds.items()} + + def default_device(self): + # By default JAX arrays are uncommitted (device=None), meaning that + # JAX is free to choose the most efficient device placement. + return None + + def devices(self): + return api.devices() + + def capabilities(self): + return self._capabilities + + def default_dtypes(self, *, device: xc.Device | Sharding | None = None): + # Array API supported dtypes are device-independent in JAX + del device + default_dtypes = { + "real floating": "f", + "complex floating": "c", + "integral": "i", + "indexing": "i", + } + return { + dtype_name: dtypes.canonicalize_dtype( + dtypes._default_types.get(kind) + ) for dtype_name, kind in default_dtypes.items() + } + + def dtypes( + self, *, + device: xc.Device | Sharding | None = None, + kind: str | Tuple[str, ...] | None = None): + # Array API supported dtypes are device-independent in JAX + del device + data_types = self._build_dtype_dict() + if kind is None: + out_dict = data_types["numeric"] | data_types["bool"] + elif isinstance(kind, tuple): + out_dict = {} + for _kind in kind: + out_dict |= data_types[_kind] + else: + out_dict = data_types[kind] + return out_dict diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index f7375a80fa8a..a7ec50e328d7 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -36,12 +36,12 @@ from __future__ import annotations -from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__ - from jax.experimental.array_api import fft as fft from jax.experimental.array_api import linalg as linalg from jax.numpy import ( + __array_api_version__ as __array_api_version__, + __array_namespace_info__ as __array_namespace_info__, abs as abs, acos as acos, acosh as acosh, @@ -203,11 +203,3 @@ std as std, var as var, ) - -from jax.experimental.array_api._utility_functions import ( - __array_namespace_info__ as __array_namespace_info__, -) - -from jax.experimental.array_api import _array_methods -_array_methods.add_array_object_methods() -del _array_methods diff --git a/jax/experimental/array_api/_array_methods.py b/jax/experimental/array_api/_array_methods.py deleted file mode 100644 index 2b071db573a8..000000000000 --- a/jax/experimental/array_api/_array_methods.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import Any - -import jax -from jax._src.array import ArrayImpl -from jax.experimental.array_api._version import __array_api_version__ -from jax.sharding import Sharding - -from jax._src.lib import xla_extension as xe - - -def _array_namespace(self, /, *, api_version: None | str = None): - if api_version is not None and api_version != __array_api_version__: - raise ValueError(f"{api_version=!r} is not available; " - f"available versions are: {[__array_api_version__]}") - return jax.experimental.array_api - - -def _to_device(self, device: xe.Device | Sharding | None, *, - stream: int | Any | None = None): - if stream is not None: - raise NotImplementedError("stream argument of array.to_device()") - return jax.device_put(self, device) - - -def add_array_object_methods(): - # TODO(jakevdp): set on tracers as well? - setattr(ArrayImpl, "__array_namespace__", _array_namespace) - setattr(ArrayImpl, "to_device", _to_device) - setattr(ArrayImpl, "device", property(lambda self: self.sharding)) diff --git a/jax/experimental/array_api/_utility_functions.py b/jax/experimental/array_api/_utility_functions.py deleted file mode 100644 index c5dac25fd8c6..000000000000 --- a/jax/experimental/array_api/_utility_functions.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import jax -from typing import Tuple -from jax._src.sharding import Sharding -from jax._src.lib import xla_client as xc -from jax._src import dtypes as _dtypes, config - -# TODO(micky774): Add to jax.numpy.util when finalizing jax.experimental.array_api -# deprecation -class __array_namespace_info__: - - def __init__(self): - self._capabilities = { - "boolean indexing": True, - "data-dependent shapes": False, - } - - - def _build_dtype_dict(self): - array_api_types = { - "bool", "int8", "int16", - "int32", "uint8", "uint16", - "uint32", "float32", "complex64" - } - if config.enable_x64.value: - array_api_types |= {"int64", "uint64", "float64", "complex128"} - return {category: {t.name: t for t in types if t.name in array_api_types} - for category, types in _dtypes._dtype_kinds.items()} - - def default_device(self): - # By default JAX arrays are uncommitted (device=None), meaning that - # JAX is free to choose the most efficient device placement. - return None - - def devices(self): - return jax.devices() - - def capabilities(self): - return self._capabilities - - def default_dtypes(self, *, device: xc.Device | Sharding | None = None): - # Array API supported dtypes are device-independent in JAX - del device - default_dtypes = { - "real floating": "f", - "complex floating": "c", - "integral": "i", - "indexing": "i", - } - return { - dtype_name: _dtypes.canonicalize_dtype( - _dtypes._default_types.get(kind) - ) for dtype_name, kind in default_dtypes.items() - } - - def dtypes( - self, *, - device: xc.Device | Sharding | None = None, - kind: str | Tuple[str, ...] | None = None): - # Array API supported dtypes are device-independent in JAX - del device - data_types = self._build_dtype_dict() - if kind is None: - out_dict = data_types["numeric"] | data_types["bool"] - elif isinstance(kind, tuple): - out_dict = {} - for _kind in kind: - out_dict |= data_types[_kind] - else: - out_dict = data_types[kind] - return out_dict diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 1b9a990f3a0d..c98624d05301 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -445,6 +445,11 @@ from jax._src.numpy.vectorize import vectorize as vectorize +from jax._src.numpy._version import __array_api_version__ +from jax._src.numpy.util import ( + __array_namespace_info__ as __array_namespace_info__, +) + # Dynamically register numpy-style methods on JAX arrays. from jax._src.numpy.array_methods import register_jax_array_methods register_jax_array_methods()