Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Mar 12, 2024
1 parent d0eae05 commit dfd35ab
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions jax/experimental/array_api/_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations


import builtins
import functools
from typing import NamedTuple
from typing import Any, NamedTuple
import jax
import jax.numpy as jnp
from jax._src.lib import xla_client as xc
from jax._src.sharding import Sharding
from jax._src.api import device_put


from jax.experimental.array_api._dtypes import (
Expand Down Expand Up @@ -124,9 +129,22 @@ def _promote_types(t1, t2):
raise ValueError("No promotion path for {t1} & {t2}")


def astype(x, dtype, /, *, copy=True):
return jnp.array(x, dtype=dtype, copy=copy)

def astype(x, dtype, /, *, copy: builtins.bool | None = True, device: xc.Device | Sharding | Any | None = None):
arr = jnp.array(x, dtype=dtype)
src_device = arr.devices().pop()
# TODO(micky774): refactor into a common utility with _place_array in gh-20175
if device:
if copy is not None and not copy:
raise ValueError(
f"Specified {device=} which requires a copy since the source device "
f"is {repr(src_device)}, however copy=False. Set copy=True or "
"copy=None to perform the requested operation."
)
else:
return device_put(arr, device)
if copy:
return jnp.array(arr, copy=True)
return arr

def can_cast(from_, to, /):
if isinstance(from_, jax.Array):
Expand Down

0 comments on commit dfd35ab

Please sign in to comment.