Skip to content

Commit

Permalink
Add __array_namespace_info__ and corresponding utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Mar 18, 2024
1 parent aaeeaf5 commit b0eb8b9
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 0 deletions.
1 change: 1 addition & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@
)

from jax.experimental.array_api._utility_functions import (
__array_namespace_info__ as __array_namespace_info__,
all as all,
any as any,
)
Expand Down
91 changes: 91 additions & 0 deletions jax/experimental/array_api/_utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import jax
from typing import List, Tuple
from jax._src.sharding import Sharding
from jax._src.xla_bridge import backends
from jax._src.lib import xla_client as xc
from jax._src import dtypes as _dtypes, config


def all(x, /, *, axis=None, keepdims=False):
Expand All @@ -23,3 +29,88 @@ def all(x, /, *, axis=None, keepdims=False):
def any(x, /, *, axis=None, keepdims=False):
"""Tests whether any input array element evaluates to True along a specified axis."""
return jax.numpy.any(x, axis=axis, keepdims=keepdims)

class __array_namespace_info__:

def __init__(self):
self._default_dtypes = self._build_default_dtype_dict()
self._capabilities = {
"boolean indexing": True,
"data-dependent shapes": False,
}
self._data_types = self._build_dtype_dict()

def _build_default_dtype_dict(self):
default_dtypes = {
"real floating": "f",
"complex floating": "c",
"integral": "i",
"indexing": "i",
}
for dtype_name, kind in default_dtypes.items():
dtype = _dtypes._default_types.get(kind)
dtype = _dtypes.canonicalize_dtype(dtype)
default_dtypes[dtype_name] = dtype
return default_dtypes

def _build_dtype_dict(self):
data_types = {
"signed integer": ["int8", "int16", "int32", "int64"],
"unsigned integer": ["uint8", "uint16", "uint32", "uint64"],
"real floating": ["float32", "float64"],
"complex floating": ["complex64", "complex128"],
}
if not config.enable_x64.value:
for category in data_types:
data_types[category] = data_types[category][:-1]

data_types["bool"] = ["bool"]

for category in data_types:
_dtype_dict = {}
for name in data_types[category]:
_dtype_dict[name] = _dtypes.dtype(name)
data_types[category] = _dtype_dict
data_types["integral"] = (
data_types["signed integer"] | data_types["unsigned integer"]
)
data_types["numeric"] = (
data_types["integral"]
| data_types["real floating"]
| data_types["complex floating"]
)
return data_types

def default_device(self):
# Note that since arrays are create uncommitted they technically do not
# have a default device in the classical sense. Functions that accept a
# device parameter generally have a default value of None anyways, so
# to reconcile those two facts we simply return None as our default device
# so that callers e.g. use jax.device_put(x, jnp.default_device()) to
# achieve equivalent results to jax.device_put(x). See gh-20200 for
# details.
return None

def devices(self):
available_devices: List[xc.Device] = []
for _backend in backends():
available_devices.extend(jax.devices(_backend))
return available_devices

def capabilities(self):
return self._capabilities

def default_dtypes(self):
return self._default_dtypes

def dtypes(self, *, device: xc.Device | Sharding | None = None, kind: str | Tuple[str, ...] | None = None):
# Array API supported dtypes are device-independent in JAX
if kind is None:
out_dict = self._data_types["numeric"] | self._data_types["bool"]
elif isinstance(kind, tuple):
out_dict = {}
for _kind in kind:
out_dict |= self._data_types[_kind]
else:
out_dict = self._data_types[kind]
return out_dict
86 changes: 86 additions & 0 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
from absl.testing import absltest
import jax
from jax import config
import jax.numpy as jnp
from jax.experimental import array_api
from jax._src.dtypes import _default_types, canonicalize_dtype
from jax._src import test_util as jtu
from absl.testing import parameterized

config.parse_flags_with_absl()

Expand Down Expand Up @@ -233,5 +237,87 @@ def test_array_namespace_method(self):
self.assertIs(x.__array_namespace__(), array_api)


class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase):

info = array_api.__array_namespace_info__()

def setUp(self):
super().setUp()
self._boolean = self.build_target_dict(["bool"])
self._signed = self.build_target_dict(["int8", "int16", "int32"])
self._unsigned = self.build_target_dict(["uint8", "uint16", "uint32"])
self._floating = self.build_target_dict(["float32"])
self._complex = self.build_target_dict(["complex64"])
if config.jax_enable_x64:
self._signed["int64"] = jnp.dtype("int64")
self._unsigned["uint64"] = jnp.dtype("uint64")
self._floating["float64"] = jnp.dtype("float64")
self._complex["complex128"] = jnp.dtype("complex128")

def build_target_dict(self, dtypes):
out = {}
for name in dtypes:
out[name] = jnp.dtype(name)
return out

def test_capabilities_info(self):
capabilities = self.info.capabilities()
assert capabilities["boolean indexing"]
assert not capabilities["data-dependent shapes"]

def test_default_device(self):
assert self.info.default_device() is None

def test_devices_info(self):
devices = self.info.devices()
x = array_api.arange(5)
# Sinfoty check that the outputs of __array_namespace_info__.devices() can
# be directly passed to Array API creation functions
for device in devices:
self.assertArraysEqual(x, array_api.arange(x.shape[0], device=device))

def test_default_dtypes_info(self):
_default_dtypes = {
"real floating": "f",
"complex floating": "c",
"integral": "i",
"indexing": "i",
}
for dtype_name, kind in _default_dtypes.items():
dtype = _default_types.get(kind)
dtype = canonicalize_dtype(dtype)
_default_dtypes[dtype_name] = dtype
assert self.info.default_dtypes() == _default_dtypes

@parameterized.parameters("bool", "signed integer", "real floating",
"complex floating", "integral", "numeric", None,
(("real floating", "complex floating"),),
(("integral", "signed integer"),),
(("integral", "bool"),))
def test_dtypes_info(self, kind):

info_dict = self.info.dtypes(kind=kind)
control = {
"bool":self._boolean,
"signed integer":self._signed,
"unsigned integer":self._unsigned,
"real floating":self._floating,
"complex floating":self._complex,
}
control["integral"] = self._signed | self._unsigned
control["numeric"] = (
self._signed | self._unsigned | self._floating | self._complex
)
target_dict = {}
if kind is None:
target_dict = control["numeric"] | self._boolean
elif isinstance(kind, tuple):
target_dict = {}
for _kind in kind:
target_dict |= control[_kind]
else:
target_dict = control[kind]
assert info_dict == target_dict

if __name__ == '__main__':
absltest.main()

0 comments on commit b0eb8b9

Please sign in to comment.