[memories] Overlap shard transfers in async_serialize #22169
Builds upon #22114, which introduces a fast-path for async checkpoint saving in which each array shard is copied through a single device-to-pinned-host transfer. So far, all of these transfers are serialized.
The present PR allows overlapping the transfers of a single array's shards. This is achieved by inserting `asyncio.sleep(0)` right after the transfer has been started, thereby permitting the `_write_array(shard)` coroutine to yield.

Why do we need the sleep? Python's `await` doesn't actually guarantee that control will be yielded to the event loop so that other coroutines can be scheduled. Python ties coroutine semantics to those of generators. As I understand it, the only thing that causes control to pass to the event loop (and thus allows other coroutines to be scheduled) is a `yield`. What we therefore need to achieve overlap in our situation is something like a bare `yield` right after the transfer has been started. This is, in fact, what `asyncio.sleep(0)` does.

An alternative, more roundabout way to achieve overlap would be to split `_write_array(shard)` into a first phase that initiates all the transfers and provides the resulting host-resident arrays to a second phase, in which we invoke `t.write(data)` as before.

cc @yashk2810
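
For reference, the scheduling behavior described above can be reproduced with plain `asyncio`, independent of the checkpointing code. The sketch below is illustrative only: `start_transfer` and `write_shard` are hypothetical stand-ins (not the functions touched by this PR), and the "transfer" is just a log entry. It shows that awaiting a coroutine with no suspension point does not let sibling coroutines run, whereas adding `asyncio.sleep(0)` does.

```python
import asyncio


async def start_transfer(log, shard):
    # Hypothetical stand-in for kicking off a device-to-host copy.
    # It contains no suspension point, so awaiting it never returns
    # control to the event loop.
    log.append(f"start {shard}")


async def write_shard(log, shard, yield_after_start):
    await start_transfer(log, shard)
    if yield_after_start:
        # Hand control back to the event loop so that the other
        # shards get a chance to start their transfers.
        await asyncio.sleep(0)
    # Hypothetical stand-in for the subsequent write of the shard data.
    log.append(f"write {shard}")


async def run(yield_after_start):
    log = []
    await asyncio.gather(
        *(write_shard(log, shard, yield_after_start) for shard in range(3))
    )
    return log


serialized = asyncio.run(run(yield_after_start=False))
overlapped = asyncio.run(run(yield_after_start=True))

print(serialized)  # each shard starts and writes before the next one starts
print(overlapped)  # all shards start before any of them writes
```

Without the sleep, each shard runs to completion before the next one begins; with it, all three "transfers" are started before the first write happens, which is the overlap this PR is after.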