Update jax.experimental.array_api to v2023.12 API #20200

Open
12 of 13 tasks
Micky774 opened this issue Mar 12, 2024 · 2 comments
Labels
enhancement New feature or request P1 (soon) Assignee is working on this now, among other tasks. (Assignee required)


@Micky774
Collaborator

Micky774 commented Mar 12, 2024

This issue tracks the changes necessary to adopt the v2023.12 Array API. This was originally mentioned in #18353. We may already satisfy some of the specifications, but the vast majority will need alterations.

API Updates

New API

Breaking Changes

For specific details on what has changed, see the notes beginning "Changed in version 2023.12 ..." on the relevant specification pages.

Common Utilities Refactor

  • Device placement + copy semantics
    • Initial draft can be found here
    • Get Device from Device | Sharding
@jakevdp
Collaborator

jakevdp commented Mar 14, 2024

Thanks for adding this! One overall note: our eventual goal is to remove jax.experimental.array_api and just use the jax.numpy namespace directly. So as much as possible, we should aim for API functions in jax.experimental.array_api to not have any additional logic beyond just calling the jax.numpy counterpart.
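
To illustrate the "no additional logic" goal, here is a minimal sketch of what pass-through functions might look like. The wrapper definitions are hypothetical and are not taken from the actual jax.experimental.array_api source; they only assume that jnp.add and jnp.matmul already behave as the specification requires.

```python
import jax.numpy as jnp

# Hypothetical thin wrappers: each one forwards straight to its
# jax.numpy counterpart, with no extra validation or conversion.
def add(x1, x2, /):
    return jnp.add(x1, x2)

def matmul(x1, x2, /):
    return jnp.matmul(x1, x2)

x = jnp.arange(3.0)
total = add(x, x)      # same result as jnp.add(x, x)
inner = matmul(x, x)   # same result as jnp.matmul(x, x)
```

Once every wrapper has this shape, replacing the experimental namespace with jax.numpy itself would presumably not change behavior for callers.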

@jakevdp
Collaborator

jakevdp commented Mar 14, 2024

A note on default_device (cf. a discussion with @yashk2810). JAX doesn't really have a concept of "default device" in the way that the Array API envisions it. By default, arrays are created uncommitted, so the only way to write a consistent default_device function would be for it to return "uncommitted". Currently most (if not all) functions that accept a device parameter default it to None, so our default_device function should probably look like this:

def default_device():
  return None

and be documented appropriately. That's the only way for, e.g. jax.device_put(x, device=jnp.default_device()) to have the same behavior as jax.device_put(x), which seems like a sensible requirement for the concept of a default!
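
The requirement above can be checked with a short sketch. It assumes default_device is defined locally as proposed (it is not an existing jnp function here) and compares placements through the devices() method of the resulting arrays.

```python
import jax
import jax.numpy as jnp

def default_device():
    # Proposed behavior from the comment above: None stands in for
    # "uncommitted", since JAX has no single default device.
    return None

x = jnp.arange(4)

# Passing the proposed default explicitly should behave the same as
# omitting the device argument altogether.
explicit = jax.device_put(x, device=default_device())
implicit = jax.device_put(x)

same_placement = explicit.devices() == implicit.devices()
same_values = bool(jnp.array_equal(explicit, implicit))
```

Under these assumptions both comparisons are expected to hold, which is the sense in which None acts as a consistent default.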
