Skip to content

Commit

Permalink
Replace XLACompatibleSharding with jax.sharding.Sharding since th…
Browse files Browse the repository at this point in the history
…e former is deprecated and will be removed in the future.

PiperOrigin-RevId: 640572753
  • Loading branch information
yashk2810 authored and Flax Authors committed Jun 5, 2024
1 parent 617a6cd commit 21231af
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion flax/nnx/nnx/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def jit(
if the sharding cannot be inferred.
The valid resource assignment specifications are:
- :py:class:`XLACompatibleSharding`, which will decide how the value
- :py:class:`Sharding`, which will decide how the value
will be partitioned. With this, using a mesh context manager is not
required.
- :py:obj:`None`, will give JAX the freedom to choose whatever sharding
Expand Down

0 comments on commit 21231af

Please sign in to comment.