Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add __array_namespace_info__ and corresponding utilities #20294

Merged
merged 1 commit into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@
)

from jax.experimental.array_api._utility_functions import (
__array_namespace_info__ as __array_namespace_info__,
all as all,
any as any,
)
Expand Down
70 changes: 69 additions & 1 deletion jax/experimental/array_api/_utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
from __future__ import annotations

import jax
from typing import Tuple
from jax._src.sharding import Sharding
from jax._src.lib import xla_client as xc
from jax._src import dtypes as _dtypes, config

def all(x, /, *, axis=None, keepdims=False):
"""Tests whether all input array elements evaluate to True along a specified axis."""
Expand All @@ -23,3 +28,66 @@ def all(x, /, *, axis=None, keepdims=False):
def any(x, /, *, axis=None, keepdims=False):
"""Tests whether any input array element evaluates to True along a specified axis."""
return jax.numpy.any(x, axis=axis, keepdims=keepdims)

class __array_namespace_info__:

def __init__(self):
self._capabilities = {
"boolean indexing": True,
"data-dependent shapes": False,
}


def _build_dtype_dict(self):
Micky774 marked this conversation as resolved.
Show resolved Hide resolved
array_api_types = {
"bool", "int8", "int16",
"int32", "uint8", "uint16",
"uint32", "float32", "complex64"
}
if config.enable_x64.value:
array_api_types |= {"int64", "uint64", "float64", "complex128"}
return {category: {t.name: t for t in types if t.name in array_api_types}
for category, types in _dtypes._dtype_kinds.items()}

def default_device(self):
Micky774 marked this conversation as resolved.
Show resolved Hide resolved
Micky774 marked this conversation as resolved.
Show resolved Hide resolved
# By default JAX arrays are uncommitted (device=None), meaning that
# JAX is free to choose the most efficient device placement.
return None

def devices(self):
return jax.devices()

def capabilities(self):
return self._capabilities

def default_dtypes(self, *, device: xc.Device | Sharding | None = None):
# Array API supported dtypes are device-independent in JAX
del device
default_dtypes = {
"real floating": "f",
"complex floating": "c",
"integral": "i",
"indexing": "i",
}
return {
dtype_name: _dtypes.canonicalize_dtype(
_dtypes._default_types.get(kind)
) for dtype_name, kind in default_dtypes.items()
}

def dtypes(
self, *,
device: xc.Device | Sharding | None = None,
kind: str | Tuple[str, ...] | None = None):
# Array API supported dtypes are device-independent in JAX
del device
data_types = self._build_dtype_dict()
if kind is None:
out_dict = data_types["numeric"] | data_types["bool"]
elif isinstance(kind, tuple):
out_dict = {}
for _kind in kind:
out_dict |= data_types[_kind]
else:
out_dict = data_types[kind]
return out_dict
1 change: 1 addition & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ py_test(
deps = [
"//jax",
"//jax:experimental_array_api",
"//jax:test_util",
] + py_deps("absl/testing"),
)

Expand Down
86 changes: 84 additions & 2 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@

from types import ModuleType

from absl.testing import absltest
from absl.testing import absltest, parameterized
import jax
from jax import config
import jax.numpy as jnp
from jax._src import config, test_util as jtu
from jax._src.dtypes import _default_types, canonicalize_dtype
from jax.experimental import array_api

config.parse_flags_with_absl()
Expand Down Expand Up @@ -232,6 +234,86 @@ def test_array_namespace_method(self):
self.assertIsInstance(x, jax.Array)
self.assertIs(x.__array_namespace__(), array_api)

class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase):

info = array_api.__array_namespace_info__()

def setUp(self):
super().setUp()
self._boolean = self.build_dtype_dict(["bool"])
self._signed = self.build_dtype_dict(["int8", "int16", "int32"])
self._unsigned = self.build_dtype_dict(["uint8", "uint16", "uint32"])
self._floating = self.build_dtype_dict(["float32"])
self._complex = self.build_dtype_dict(["complex64"])
if config.enable_x64.value:
self._signed["int64"] = jnp.dtype("int64")
self._unsigned["uint64"] = jnp.dtype("uint64")
self._floating["float64"] = jnp.dtype("float64")
self._complex["complex128"] = jnp.dtype("complex128")
self._integral = self._signed | self._unsigned
self._numeric = (
self._signed | self._unsigned | self._floating | self._complex
)
def build_dtype_dict(self, dtypes):
out = {}
for name in dtypes:
out[name] = jnp.dtype(name)
return out

def test_capabilities_info(self):
capabilities = self.info.capabilities()
assert capabilities["boolean indexing"]
assert not capabilities["data-dependent shapes"]

def test_default_device_info(self):
assert self.info.default_device() is None

def test_devices_info(self):
assert self.info.devices() == jax.devices()

def test_default_dtypes_info(self):
_default_dtypes = {
"real floating": "f",
"complex floating": "c",
"integral": "i",
"indexing": "i",
}
target_dict = {
dtype_name: canonicalize_dtype(
_default_types.get(kind)
) for dtype_name, kind in _default_dtypes.items()
}
assert self.info.default_dtypes() == target_dict

@parameterized.parameters(
"bool", "signed integer", "real floating",
"complex floating", "integral", "numeric", None,
(("real floating", "complex floating"),),
(("integral", "signed integer"),),
(("integral", "bool"),),
)
def test_dtypes_info(self, kind):

info_dict = self.info.dtypes(kind=kind)
control = {
"bool":self._boolean,
"signed integer":self._signed,
"unsigned integer":self._unsigned,
"real floating":self._floating,
"complex floating":self._complex,
"integral": self._integral,
"numeric": self._numeric
}
target_dict = {}
if kind is None:
target_dict = control["numeric"] | self._boolean
elif isinstance(kind, tuple):
target_dict = {}
for _kind in kind:
target_dict |= control[_kind]
else:
target_dict = control[kind]
assert info_dict == target_dict

if __name__ == '__main__':
absltest.main()