Skip to content

Commit

Permalink
Add __array_namespace_info__ and corresponding utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Mar 18, 2024
1 parent aaeeaf5 commit 7f4f7ef
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 3 deletions.
1 change: 1 addition & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@
)

from jax.experimental.array_api._utility_functions import (
__array_namespace_info__ as __array_namespace_info__,
all as all,
any as any,
)
Expand Down
84 changes: 83 additions & 1 deletion jax/experimental/array_api/_utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
from __future__ import annotations

import jax
from typing import Tuple
from jax._src.sharding import Sharding
from jax._src.lib import xla_client as xc
from jax._src import dtypes as _dtypes, config

def all(x, /, *, axis=None, keepdims=False):
"""Tests whether all input array elements evaluate to True along a specified axis."""
Expand All @@ -23,3 +28,80 @@ def all(x, /, *, axis=None, keepdims=False):
def any(x, /, *, axis=None, keepdims=False):
"""Tests whether any input array element evaluates to True along a specified axis."""
return jax.numpy.any(x, axis=axis, keepdims=keepdims)

class __array_namespace_info__:

def __init__(self):
self._default_dtypes = self._build_default_dtype_dict()
self._capabilities = {
"boolean indexing": True,
"data-dependent shapes": False,
}
self._data_types = self._build_dtype_dict()

def _build_default_dtype_dict(self):
default_dtypes = {
"real floating": "f",
"complex floating": "c",
"integral": "i",
"indexing": "i",
}
for dtype_name, kind in default_dtypes.items():
dtype = _dtypes._default_types.get(kind)
dtype = _dtypes.canonicalize_dtype(dtype)
default_dtypes[dtype_name] = dtype
return default_dtypes

def _build_dtype_dict(self):
data_types = {
"signed integer": ["int8", "int16", "int32", "int64"],
"unsigned integer": ["uint8", "uint16", "uint32", "uint64"],
"real floating": ["float32", "float64"],
"complex floating": ["complex64", "complex128"],
}
if not config.enable_x64.value:
for category in data_types:
data_types[category] = data_types[category][:-1]

data_types["bool"] = ["bool"]

for category in data_types:
_dtype_dict = {}
for name in data_types[category]:
_dtype_dict[name] = _dtypes.dtype(name)
data_types[category] = _dtype_dict
data_types["integral"] = (
data_types["signed integer"] | data_types["unsigned integer"]
)
data_types["numeric"] = (
data_types["integral"]
| data_types["real floating"]
| data_types["complex floating"]
)
return data_types

def default_device(self):
# By default JAX arrays are uncommitted (device=None), meaning that
# JAX is free to choose the most efficient device placement.
return None

def devices(self):
return jax.devices()

def capabilities(self):
return self._capabilities

def default_dtypes(self):
return self._default_dtypes

def dtypes(self, *, device: xc.Device | Sharding | None = None, kind: str | Tuple[str, ...] | None = None):
# Array API supported dtypes are device-independent in JAX
if kind is None:
out_dict = self._data_types["numeric"] | self._data_types["bool"]
elif isinstance(kind, tuple):
out_dict = {}
for _kind in kind:
out_dict |= self._data_types[_kind]
else:
out_dict = self._data_types[kind]
return out_dict
82 changes: 80 additions & 2 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@

from types import ModuleType

from absl.testing import absltest
from absl.testing import absltest, parameterized
import jax
from jax import config
import jax.numpy as jnp
from jax._src import config, test_util as jtu
from jax._src.dtypes import _default_types, canonicalize_dtype
from jax.experimental import array_api

config.parse_flags_with_absl()
Expand Down Expand Up @@ -232,6 +234,82 @@ def test_array_namespace_method(self):
self.assertIsInstance(x, jax.Array)
self.assertIs(x.__array_namespace__(), array_api)

class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase):

info = array_api.__array_namespace_info__()

def setUp(self):
super().setUp()
self._boolean = self.build_target_dict(["bool"])
self._signed = self.build_target_dict(["int8", "int16", "int32"])
self._unsigned = self.build_target_dict(["uint8", "uint16", "uint32"])
self._floating = self.build_target_dict(["float32"])
self._complex = self.build_target_dict(["complex64"])
if config.enable_x64.value:
self._signed["int64"] = jnp.dtype("int64")
self._unsigned["uint64"] = jnp.dtype("uint64")
self._floating["float64"] = jnp.dtype("float64")
self._complex["complex128"] = jnp.dtype("complex128")

def build_target_dict(self, dtypes):
out = {}
for name in dtypes:
out[name] = jnp.dtype(name)
return out

def test_capabilities_info(self):
capabilities = self.info.capabilities()
assert capabilities["boolean indexing"]
assert not capabilities["data-dependent shapes"]

def test_default_device_info(self):
assert self.info.default_device() is None

def test_devices_info(self):
assert self.info.devices() == jax.devices()

def test_default_dtypes_info(self):
_default_dtypes = {
"real floating": "f",
"complex floating": "c",
"integral": "i",
"indexing": "i",
}
for dtype_name, kind in _default_dtypes.items():
dtype = _default_types.get(kind)
dtype = canonicalize_dtype(dtype)
_default_dtypes[dtype_name] = dtype
assert self.info.default_dtypes() == _default_dtypes

@parameterized.parameters("bool", "signed integer", "real floating",
"complex floating", "integral", "numeric", None,
(("real floating", "complex floating"),),
(("integral", "signed integer"),),
(("integral", "bool"),))
def test_dtypes_info(self, kind):

info_dict = self.info.dtypes(kind=kind)
control = {
"bool":self._boolean,
"signed integer":self._signed,
"unsigned integer":self._unsigned,
"real floating":self._floating,
"complex floating":self._complex,
}
control["integral"] = self._signed | self._unsigned
control["numeric"] = (
self._signed | self._unsigned | self._floating | self._complex
)
target_dict = {}
if kind is None:
target_dict = control["numeric"] | self._boolean
elif isinstance(kind, tuple):
target_dict = {}
for _kind in kind:
target_dict |= control[_kind]
else:
target_dict = control[kind]
assert info_dict == target_dict

if __name__ == '__main__':
absltest.main()

0 comments on commit 7f4f7ef

Please sign in to comment.