[memories] Transfer to pinned_host fast path in async_serialize #22114
Conversation
```python
# If available, transfer to pinned host memory
sharding = jax.sharding.SingleDeviceSharding(shard.device,
                                             memory_kind="pinned_host")
data = jax.jit(lambda x: x, out_shardings=sharding)(data)
```
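For context, here's a minimal standalone sketch of this transfer pattern, assuming a backend that exposes a `pinned_host` memory space (the array contents and shapes are illustrative, not from the PR):

```python
import jax
import jax.numpy as jnp

# Illustrative input living in default device memory.
arr = jnp.arange(1 << 20, dtype=jnp.float32)

for shard in arr.addressable_shards:
    # Request an output sharding on the same device, but in the
    # pinned host memory space.
    sharding = jax.sharding.SingleDeviceSharding(
        shard.device, memory_kind="pinned_host")
    # A jitted identity computation forces XLA to materialize its
    # output in pinned host memory, performing the D2H transfer.
    data = jax.jit(lambda x: x, out_shardings=sharding)(shard.data)
```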
Use `jax.device_put` instead?
Unfortunately, that doesn't work yet with `memory_kind="pinned_host"`, and based on our attempts we believe this will require a longer tail of fixes to XLA. @jaro-sevcik can elaborate.
It should! We have tests in memories_test.py that show it does work. It would be nice not to run an XLA computation for a transfer of this kind.
We need a couple more patches in XLA for non-jitted device_put to work: first openxla/xla#14089 (already submitted for review), and then the last commit from https://github.com/jaro-sevcik/xla/tree/device-put-memory-kind-sharding (to be submitted once openxla/xla#14089 lands).
Switched over to using `device_put` in this PR.
For completeness, this PR now depends on the XLA PR openxla/xla#14268 (that enables copying buffers to a different memory space).
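With those XLA changes in place, the jitted identity presumably reduces to a plain `jax.device_put`; a sketch under the same assumptions as before:

```python
import jax
import jax.numpy as jnp

arr = jnp.arange(1 << 20, dtype=jnp.float32)

for shard in arr.addressable_shards:
    sharding = jax.sharding.SingleDeviceSharding(
        shard.device, memory_kind="pinned_host")
    # Direct copy into pinned host memory, without compiling and
    # dispatching an XLA identity computation.
    data = jax.device_put(shard.data, sharding)
```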
```python
                                             memory_kind="pinned_host")
data = jax.jit(lambda x: x, out_shardings=sharding)(data)
# Allow other transfers to be scheduled simultaneously
await asyncio.sleep(0)
```
Why do we need this sleep? `await` should schedule it concurrently anyway, right?
(Deferred to a separate commit)
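For intuition: `await` only yields to the event loop at an actual suspension point, so any synchronous work before the next suspension runs to completion first. `await asyncio.sleep(0)` inserts an explicit yield so sibling coroutines can issue their transfers in between. A minimal sketch (the simulated dispatch work is illustrative):

```python
import asyncio
import time

async def issue_transfer(name):
    print(f"{name}: issuing transfer")
    time.sleep(0.1)  # stand-in for synchronous dispatch work
    await asyncio.sleep(0)  # explicit yield point
    print(f"{name}: transfer issued, awaiting completion")

async def main():
    # Without the sleep(0), each coroutine runs its synchronous part
    # back-to-back; with it, "a" and "b" interleave their prints.
    await asyncio.gather(issue_transfer("a"), issue_transfer("b"))

asyncio.run(main())
```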
Do you have some benchmarks where this is super fast and helpful? (Something that you ran locally or in your runs?)
Force-pushed from 976610e to 4488182
Here's a self-contained example that doesn't quite behave like the E2E workload mentioned before, but illustrates the effects and is a good candidate for profiling: https://gist.github.com/gspschmid/52a1062916c7030a513b0581bd56c5be

The first improvement corresponds to this PR (along with its XLA dependencies); the second improvement corresponds to #22169. Note that after applying the first improvement, other overheads in tensorstore's […]
Force-pushed from 4488182 to 62940cf
@yashk2810 Not sure what CI you run for JAX contributions, but now that the remaining XLA PRs are in (openxla/xla#14089 and openxla/xla#14268), this should be ready to test.
Adds a fast path to `jax.experimental.array_serialization.serialization.async_serialize` that avoids XLA's regular device-to-host transfer and instead uses a single device-to-pinned-host transfer per `_write_array(arr)` invocation. This allows us to achieve much closer to ideal transfer bandwidths in practice. For comparison, the existing approach stages copies through a fixed-size intermediate 128MB buffer and requires `sizeof(arr)/128MB` alternations between D2H and H2H copies.

Note that the `np.array(data, copy=False)` is not strictly necessary, as the tensorstore invocation `t.write(...)` immediately performs the C-API equivalent of `np.array(data, copy=None)`. We expect all of these to be zero-copy, hence explicitly calling `np.array(data, copy=False)` provides some extra safety: it would fail if jax.Array's implementation changed and no longer permitted zero-copying its private numpy `_value`. Alas, the latter check is not fool-proof: for example, prior to XLA#14089 the construction of the jax.Array from the device buffer also forced a copy.

Depends on XLA#14087 XLA#14088 XLA#14089 XLA#14090
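Putting the pieces together, a hedged end-to-end sketch of the fast path described above (names like `write_shard` and `t` are illustrative, not the PR's actual code; `t` stands for a tensorstore handle):

```python
import asyncio

import jax
import numpy as np

async def write_shard(shard, t):
    # One D2H transfer per shard, directly into pinned host memory,
    # instead of staging through XLA's fixed-size 128MB buffer.
    sharding = jax.sharding.SingleDeviceSharding(
        shard.device, memory_kind="pinned_host")
    data = jax.device_put(shard.data, sharding)
    # Yield so transfers for the remaining shards can be issued too.
    await asyncio.sleep(0)
    # copy=False asserts the conversion is zero-copy; it fails if
    # jax.Array could no longer expose its host buffer without copying.
    out = np.array(data, copy=False)
    # tensorstore write futures are awaitable in async code.
    await t[shard.index].write(out)
```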