Skip to content

Commit

Permalink
Update the deprecation message of backend and device argument of …
Browse files Browse the repository at this point in the history
…`jit` to be more actionable.

PiperOrigin-RevId: 637899890
  • Loading branch information
yashk2810 authored and jax authors committed May 28, 2024
1 parent de28ee6 commit ff3db9b
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,11 +377,9 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,

if backend is not None or device is not None:
warnings.warn(
'backend and device argument on jit is deprecated. You can use a '
'`jax.sharding.Mesh` context manager or device_put the arguments '
'before passing them to `jit`. Please see '
'https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html '
'for more information.', DeprecationWarning)
'backend and device argument on jit is deprecated. You can use'
' `jax.device_put(..., jax.local_devices("cpu")[0])` on the inputs to'
' the jitted function to get the same behavior.', DeprecationWarning)
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
f"got {device=} and {backend=}")
Expand Down

0 comments on commit ff3db9b

Please sign in to comment.