Skip to content

Commit

Permalink
Merge pull request #20294 from Micky774:array_namespace_info
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623877931
  • Loading branch information
jax authors committed Apr 11, 2024
2 parents d9d11a3 + e6508a4 commit 301c351
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 3 deletions.
1 change: 1 addition & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@
)

from jax.experimental.array_api._utility_functions import (
__array_namespace_info__ as __array_namespace_info__,
all as all,
any as any,
)
Expand Down
70 changes: 69 additions & 1 deletion jax/experimental/array_api/_utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
from __future__ import annotations

import jax
from typing import Tuple
from jax._src.sharding import Sharding
from jax._src.lib import xla_client as xc
from jax._src import dtypes as _dtypes, config

def all(x, /, *, axis=None, keepdims=False):
"""Tests whether all input array elements evaluate to True along a specified axis."""
Expand All @@ -23,3 +28,66 @@ def all(x, /, *, axis=None, keepdims=False):
def any(x, /, *, axis=None, keepdims=False):
"""Tests whether any input array element evaluates to True along a specified axis."""
return jax.numpy.any(x, axis=axis, keepdims=keepdims)

class __array_namespace_info__:

def __init__(self):
self._capabilities = {
"boolean indexing": True,
"data-dependent shapes": False,
}


def _build_dtype_dict(self):
array_api_types = {
"bool", "int8", "int16",
"int32", "uint8", "uint16",
"uint32", "float32", "complex64"
}
if config.enable_x64.value:
array_api_types |= {"int64", "uint64", "float64", "complex128"}
return {category: {t.name: t for t in types if t.name in array_api_types}
for category, types in _dtypes._dtype_kinds.items()}

def default_device(self):
# By default JAX arrays are uncommitted (device=None), meaning that
# JAX is free to choose the most efficient device placement.
return None

def devices(self):
return jax.devices()

def capabilities(self):
return self._capabilities

def default_dtypes(self, *, device: xc.Device | Sharding | None = None):
# Array API supported dtypes are device-independent in JAX
del device
default_dtypes = {
"real floating": "f",
"complex floating": "c",
"integral": "i",
"indexing": "i",
}
return {
dtype_name: _dtypes.canonicalize_dtype(
_dtypes._default_types.get(kind)
) for dtype_name, kind in default_dtypes.items()
}

def dtypes(
self, *,
device: xc.Device | Sharding | None = None,
kind: str | Tuple[str, ...] | None = None):
# Array API supported dtypes are device-independent in JAX
del device
data_types = self._build_dtype_dict()
if kind is None:
out_dict = data_types["numeric"] | data_types["bool"]
elif isinstance(kind, tuple):
out_dict = {}
for _kind in kind:
out_dict |= data_types[_kind]
else:
out_dict = data_types[kind]
return out_dict
1 change: 1 addition & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ py_test(
deps = [
"//jax",
"//jax:experimental_array_api",
"//jax:test_util",
] + py_deps("absl/testing"),
)

Expand Down
86 changes: 84 additions & 2 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@

from types import ModuleType

from absl.testing import absltest
from absl.testing import absltest, parameterized
import jax
from jax import config
import jax.numpy as jnp
from jax._src import config, test_util as jtu
from jax._src.dtypes import _default_types, canonicalize_dtype
from jax.experimental import array_api

config.parse_flags_with_absl()
Expand Down Expand Up @@ -233,6 +235,86 @@ def test_array_namespace_method(self):
self.assertIsInstance(x, jax.Array)
self.assertIs(x.__array_namespace__(), array_api)

class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase):

info = array_api.__array_namespace_info__()

def setUp(self):
super().setUp()
self._boolean = self.build_dtype_dict(["bool"])
self._signed = self.build_dtype_dict(["int8", "int16", "int32"])
self._unsigned = self.build_dtype_dict(["uint8", "uint16", "uint32"])
self._floating = self.build_dtype_dict(["float32"])
self._complex = self.build_dtype_dict(["complex64"])
if config.enable_x64.value:
self._signed["int64"] = jnp.dtype("int64")
self._unsigned["uint64"] = jnp.dtype("uint64")
self._floating["float64"] = jnp.dtype("float64")
self._complex["complex128"] = jnp.dtype("complex128")
self._integral = self._signed | self._unsigned
self._numeric = (
self._signed | self._unsigned | self._floating | self._complex
)
def build_dtype_dict(self, dtypes):
out = {}
for name in dtypes:
out[name] = jnp.dtype(name)
return out

def test_capabilities_info(self):
capabilities = self.info.capabilities()
assert capabilities["boolean indexing"]
assert not capabilities["data-dependent shapes"]

def test_default_device_info(self):
assert self.info.default_device() is None

def test_devices_info(self):
assert self.info.devices() == jax.devices()

def test_default_dtypes_info(self):
_default_dtypes = {
"real floating": "f",
"complex floating": "c",
"integral": "i",
"indexing": "i",
}
target_dict = {
dtype_name: canonicalize_dtype(
_default_types.get(kind)
) for dtype_name, kind in _default_dtypes.items()
}
assert self.info.default_dtypes() == target_dict

@parameterized.parameters(
"bool", "signed integer", "real floating",
"complex floating", "integral", "numeric", None,
(("real floating", "complex floating"),),
(("integral", "signed integer"),),
(("integral", "bool"),),
)
def test_dtypes_info(self, kind):

info_dict = self.info.dtypes(kind=kind)
control = {
"bool":self._boolean,
"signed integer":self._signed,
"unsigned integer":self._unsigned,
"real floating":self._floating,
"complex floating":self._complex,
"integral": self._integral,
"numeric": self._numeric
}
target_dict = {}
if kind is None:
target_dict = control["numeric"] | self._boolean
elif isinstance(kind, tuple):
target_dict = {}
for _kind in kind:
target_dict |= control[_kind]
else:
target_dict = control[kind]
assert info_dict == target_dict

class ArrayAPIErrors(absltest.TestCase):
"""Test that our array API implementations raise errors where required"""
Expand Down

0 comments on commit 301c351

Please sign in to comment.