
[memories] Transfer to pinned_host fast path in async_serialize #22114

Open

gspschmid wants to merge 1 commit into main from gschmid/async_serialize-transfer-pinned
Conversation

gspschmid (Contributor)

Adds a fast path to jax.experimental.array_serialization.serialization.async_serialize that avoids XLA's regular device-to-host transfer and instead uses a single device-to-pinned-host transfer per _write_array(arr) invocation. This gets us much closer to ideal transfer bandwidths in practice. For comparison, the existing approach stages copies through a fixed-size 128 MB intermediate buffer and hence requires sizeof(arr)/128MB alternations between D2H and H2H copies.

Note that the np.array(data, copy=False) call is not strictly necessary, as the tensorstore invocation t.write(...) immediately performs the C-API equivalent of np.array(data, copy=None). We expect all of these to be zero-copy, so explicitly calling np.array(data, copy=False) provides some extra safety: it would fail if jax.Array's implementation changed and no longer permitted zero-copying its private numpy array _value. The latter check is not fool-proof, however: for example, prior to XLA#14089 the construction of the jax.Array from the device buffer also forced a copy.

Depends on XLA#14087, XLA#14088, XLA#14089, XLA#14090.
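
For illustration, here is a minimal sketch (under assumptions, not the exact code of this PR) of the fast path plus the zero-copy guard described above; shard stands for one of the array's addressable shards, as in _write_array:

import jax
import numpy as np

def transfer_shard_to_pinned_host(shard):
    # Single device-to-pinned-host copy instead of staging through XLA's
    # fixed-size intermediate buffer.
    sharding = jax.sharding.SingleDeviceSharding(shard.device,
                                                 memory_kind="pinned_host")
    data = jax.jit(lambda x: x, out_shardings=sharding)(shard.data)
    data.block_until_ready()
    # copy=False acts as a safety check: it fails if materializing the numpy
    # view of the pinned-host buffer would require a copy. The result is then
    # handed to tensorstore's t.write(...) as before.
    return np.array(data, copy=False)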

# If available, transfer to pinned host memory
sharding = jax.sharding.SingleDeviceSharding(shard.device,
                                             memory_kind="pinned_host")
data = jax.jit(lambda x: x, out_shardings=sharding)(data)
Member

Use jax.device_put instead?

Contributor Author

Unfortunately jax.device_put doesn't work yet with memory_kind="pinned_host", and based on our attempts we believe this will require a longer tail of fixes to XLA. @jaro-sevcik can elaborate.

Member

It should! We have tests in memories_test.py that show it does work. It would be nice not to run an XLA computation for a transfer of this kind.

Contributor

We need a couple more patches in XLA for non-jitted device_put to work: first openxla/xla#14089 (already submitted for review), and then the last commit from https://github.com/jaro-sevcik/xla/tree/device-put-memory-kind-sharding (to be submitted once openxla/xla#14089 lands).

Contributor Author

Switched over to using device_put in this PR.

Contributor

For completeness, this PR now depends on the XLA PR openxla/xla#14268 (that enables copying buffers to a different memory space).
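
For reference, a minimal runnable sketch of the device_put-based transfer this thread converged on (assuming the XLA changes referenced above; the array and device here are just placeholders):

import jax
import jax.numpy as jnp

x = jnp.arange(1 << 20, dtype=jnp.float32)  # stand-in for a shard's data
pinned = jax.sharding.SingleDeviceSharding(jax.devices()[0],
                                           memory_kind="pinned_host")
# One device-to-pinned-host copy, without dispatching a compiled computation.
x_host = jax.device_put(x, pinned)
x_host.block_until_ready()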

                                             memory_kind="pinned_host")
data = jax.jit(lambda x: x, out_shardings=sharding)(data)
# Allow other transfers to be scheduled simultaneously
await asyncio.sleep(0)
Member

Why do we need this sleep? await should schedule it concurrently anyway, right?

Contributor Author

(Deferred to a separate commit)
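
For context, a toy, self-contained illustration (not the PR's code) of what the explicit yield buys: asyncio.sleep(0) is a yield point, so every coroutine gets to issue its transfer before any of them moves on to the slower write step.

import asyncio

async def write_shard(name):
    print(f"{name}: issue transfer")  # synchronous set-up work
    await asyncio.sleep(0)            # yield so the other shards can issue theirs
    print(f"{name}: wait for write")  # resumes once rescheduled by the loop

async def main():
    await asyncio.gather(write_shard("shard0"), write_shard("shard1"))

asyncio.run(main())
# Both "issue transfer" lines print before either "wait for write" line.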

@yashk2810 (Member)

Do you have some benchmarks where this is super fast and helpful? (something that you ran locally or in your runs?)

@gspschmid force-pushed the gschmid/async_serialize-transfer-pinned branch 2 times, most recently from 976610e to 4488182 on June 27, 2024, 10:46
@gspschmid (Contributor Author)

@yashk2810

Do you have some benchmarks where this is super fast and helpful? (something that you ran locally or in your runs?)

Here's a self-contained example that doesn't quite behave like the E2E workload mentioned before, but illustrates the effects and is a good candidate for profiling: https://gist.github.com/gspschmid/52a1062916c7030a513b0581bd56c5be

The first improvement corresponds to this PR (along with its XLA dependencies); the second corresponds to #22169. Note that after applying the first improvement, other overheads in tensorstore's t.write(data) begin to dominate. Attached below are some screenshots of nsys profiles corresponding to the last iteration of each variant.

Baseline: [nsys profile screenshot]

Device-to-pinned-host transfer: [nsys profile screenshot]

Device-to-pinned-host transfer + overlap shard transfers: [nsys profile screenshot]
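
For anyone who wants a quick local sanity check, here is a hypothetical micro-benchmark in the spirit of the gist above (not the same code; shapes and devices are arbitrary) that compares a plain device-to-host copy against a device-to-pinned-host transfer:

import time

import jax
import jax.numpy as jnp
import numpy as np

x = jnp.ones((256, 1024, 1024), dtype=jnp.float32)  # ~1 GiB on device
x.block_until_ready()

def report(label, seconds):
    print(f"{label}: {x.nbytes / seconds / 1e9:.2f} GB/s")

# Regular D2H path (staged through XLA's fixed-size intermediate buffer).
start = time.perf_counter()
np.asarray(x)
report("device-to-host", time.perf_counter() - start)

# Pinned-host path (requires the XLA changes this PR depends on).
pinned = jax.sharding.SingleDeviceSharding(jax.devices()[0],
                                           memory_kind="pinned_host")
start = time.perf_counter()
jax.device_put(x, pinned).block_until_ready()
report("device-to-pinned-host", time.perf_counter() - start)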

@gspschmid force-pushed the gschmid/async_serialize-transfer-pinned branch from 4488182 to 62940cf on July 1, 2024, 09:15
@gspschmid (Contributor Author)

@yashk2810 Not sure what CI you run for JAX contributions, but now that the remaining XLA PRs are in (openxla/xla#14089 and openxla/xla#14268), this should be ready to test.
