
Channel ids for collectives are not unique when custom sharding is inlined in XLA #14600

Open
Tixxx opened this issue Jul 8, 2024 · 13 comments

Comments

@Tixxx
Contributor

Tixxx commented Jul 8, 2024

We are seeing a compilation crash with a user-defined custom partitioning. I'm attaching the repro script and instructions to reproduce. The error happens when running with TransformerEngine in our JAX Toolbox Docker container, although I'd expect this to happen for any inlined custom sharding.
My repro needs an 8-GPU machine to run.
Here's the small repro script: https://github.com/Tixxx/fileshare/blob/main/scan_unbound_nvbug_with_te_refactor_v2.py
Instructions:
docker run --gpus all -it ghcr.io/nvidia/jax:pax-2024-06-18
python scan_unbound_nvbug_with_te_refactor_v2.py

The error is
"INTERNAL: RET_CHECK failure (external/xla/xla/service/hlo_verifier.cc:2494) first->opcode() == instr->opcode() channel 1 is used for different types of channel instructions"

Apparently, the SPMD partitioner assigns the same channel id to two different collectives.
When the SPMD partitioner tries to inline the custom partitioning call, it assigns channel ids to collectives in that computation incrementally, but when it later creates collectives for sharded instructions in the main computation, the state of the increment is lost and the channel ids start from the original value again. The state is passed around in numerous places when creating collectives in the partitioner, and I haven't been able to pinpoint where it is lost.
But maybe we can post-process the partitioned module to scan through collectives, detect duplicate channel ids, and re-assign unique ones (rough sketch below).
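For concreteness, here is a minimal sketch of what such a post-processing step could look like. The function name `UniquifyChannelIds` is hypothetical and the include paths are approximate; this is not the pass later proposed in the PR, just an illustration of the "first occurrence keeps its id, later duplicates get a fresh id" idea. A real pass would also have to keep legitimately paired channel instructions (e.g. send/send-done) on the same id.

```cpp
#include <algorithm>
#include <cstdint>
#include <optional>

#include "absl/container/flat_hash_set.h"
#include "xla/hlo/ir/hlo_computation.h"   // include paths approximate
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"

namespace xla {

// Hypothetical uniquifier: walk every instruction that carries a channel id,
// let the first user of each id keep it, and bump later users past the
// largest id seen in the module.
bool UniquifyChannelIds(HloModule* module) {
  // First walk: find the largest channel id already in use.
  int64_t next_id = 0;
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (std::optional<int64_t> id = instruction->channel_id()) {
        next_id = std::max(next_id, *id);
      }
    }
  }

  // Second walk: re-assign any id that has already been claimed.
  absl::flat_hash_set<int64_t> used_ids;
  bool changed = false;
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      std::optional<int64_t> id = instruction->channel_id();
      if (!id.has_value()) continue;
      if (used_ids.insert(*id).second) continue;  // First occurrence keeps it.
      instruction->set_channel_id(++next_id);     // Duplicate: give a new id.
      used_ids.insert(next_id);
      changed = true;
    }
  }
  return changed;
}

}  // namespace xla
```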

@ptoulme-aws
Contributor

I have noticed this also when unrolling while loops with custom-calls inside of them. My solution was to add a channel-id legalizer pass that enforces a unique channel id for each collective.

@Tixxx
Contributor Author

Tixxx commented Jul 10, 2024

I have noticed this also when unrolling while loops with custom-calls inside of them. My solution was to add a channel-id legalizer pass that enforces a unique channel id for each collective.

Thanks. I was thinking along the same lines: have a uniquifier pass post-process the graph.

@nouiz
Contributor

nouiz commented Jul 17, 2024

@ptoulme-aws Any hope of upstreaming your fix or making it publicly available?

@ptoulme-aws
Contributor

ptoulme-aws commented Jul 17, 2024 via email

@ptoulme-aws
Contributor

This PR should unblock you - #15002

copybara-service bot pushed a commit that referenced this issue Jul 19, 2024
Imported from GitHub PR #15002

We have found that it is not guaranteed, after all transformations (partitioning, while-loop unrolling, etc.), that all channel ids will be unique.
Rather than debug this throughout XLA, it is simpler to just add a pass that mandates unique channel ids and changes channel ids to make them unique.

Issue: #14600
Copybara import of the project:

--
4764731 by ptoulme-aws <[email protected]>:

Add unique channel id enforcer pass

Merging this change closes #15002

COPYBARA_INTEGRATE_REVIEW=#15002 from ptoulme-aws:unique_channel_id_new 4764731
PiperOrigin-RevId: 653968079
@ptoulme-aws
Contributor

@Tixxx My PR is merged. Can we close this issue now, or leave it open for SPMD debugging?

If we leave this open for debugging, I would like to mention: "I have noticed this also when unrolling while loops with custom-calls inside of them."

@Tixxx
Contributor Author

Tixxx commented Jul 19, 2024

Let's leave it open for now. I think we need to revisit where the pass needs to be run. In our case, the duplicated id appears right after the SPMD partitioner; with the HLO verifier running in different places in the pipeline, it will error out very early.
Also, it seems the PR was rolled back.

@ptoulme-aws
Contributor

@Tixxx I have a new PR. Where should we add it in the GPU compiler? This will fix the peer-to-peer failure.

@Tixxx
Contributor Author

Tixxx commented Jul 25, 2024

OK great, thanks. I think we will need to run it right after the SPMD partitioner; the error reported in this bug is caused by the partitioner assigning duplicated ids.
@frgossen Do you think this should be run as a sub-pass of the SPMD partitioner or as a stand-alone pass after the SPMD pipeline (see the sketch below)? I remember the HLO verifier runs after SPMD, so we might still hit the same error if it runs as a stand-alone pass.
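For illustration only, the stand-alone placement could look roughly like the sketch below. It assumes the pass is an ordinary HLO module pass named `UniqueChannelIdEnforcer` (name taken from the PR's commit message), that `num_partitions`, `num_replicas`, and `hlo_module` come from the surrounding compiler code, and that the GPU compiler builds its SPMD pipeline roughly as shown; the actual pipeline construction in `gpu_compiler.cc` may differ.

```cpp
// Sketch: run the enforcer immediately after partitioning, before any
// downstream pass (or the HLO verifier) can observe duplicated channel ids.
HloPassPipeline spmd_pipeline("spmd-partitioner");
spmd_pipeline.AddPass<ShardingPropagation>(/*is_spmd=*/true);
spmd_pipeline.AddPass<spmd::StatefulRngSpmdPartitioner>(num_partitions,
                                                        num_replicas);
// Hypothetical placement: re-uniquify channel ids right after the partitioner.
spmd_pipeline.AddPass<UniqueChannelIdEnforcer>();
TF_RETURN_IF_ERROR(spmd_pipeline.Run(hlo_module).status());
```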

@frgossen
Member

We had a bit of an offline discussion, and the right place to fix this would be wherever the inlining happens. The channel ids should be unique at any point, and changing them will cause problems for MPMD compilations. Note that the channel id is completely irrelevant for SPMD programs.

@Tixxx
Contributor Author

Tixxx commented Jul 25, 2024

OK, does that mean that for the problem described in this issue, we won't rely on the channel-id legalizer pass?

@Tixxx
Contributor Author

Tixxx commented Jul 29, 2024

We had a bit of an offline discussion, and the right place to fix this would be wherever the inlining happens. The channel ids should be unique at any point, and changing them will cause problems for MPMD compilations. Note that the channel id is completely irrelevant for SPMD programs.

@frgossen Is anyone from Google already looking into fixing this where the inlining happens?

@frgossen
Member

I don't think anyone is looking into this at the moment, but I'm happy to review PRs that fix this.
