
[XLA:GPU] Check failed in collective_pipeliner when using gradient accumulation with non-unrolled loop #14332

Open
qGentry opened this issue Jul 1, 2024 · 5 comments
qGentry commented Jul 1, 2024

Hi, I have the following setup:

  • Transformer model with N layers scanned over the input
  • fully sharded data parallel (FSDP) sharding
  • asynchronous communications (latency-hiding scheduler; pipelined all-gather, all-reduce, and reduce-scatter)

I'm using the following flags:

--xla_gpu_graph_level=0 
--xla_gpu_enable_triton_gemm=false 
--xla_gpu_enable_command_buffer= 
--xla_gpu_enable_latency_hiding_scheduler=true 
--xla_gpu_enable_all_gather_combine_by_dim=false 
--xla_gpu_enable_reduce_scatter_combine_by_dim=false 
--xla_gpu_enable_pipelined_all_gather=true 
--xla_gpu_enable_pipelined_reduce_scatter=true 
--xla_gpu_enable_pipelined_all_reduce=true 
--xla_gpu_enable_pipelined_collectives=false 
--xla_gpu_enable_while_loop_double_buffering=true 
--xla_gpu_enable_highest_priority_async_stream=true 
--xla_gpu_disable_async_collectives=collectivebroadcast,alltoall,collectivepermute
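For context, these are plain XLA debug options; with JAX they are usually passed through the `XLA_FLAGS` environment variable before the process starts. A minimal sketch (showing just two of the flags above):

```shell
# Pass XLA options to JAX via the XLA_FLAGS environment variable.
# This must be set before the Python process initializes JAX/XLA.
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_pipelined_all_gather=true"
```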

This works correctly and indeed hides the layers' weight all-gathers and gradient reduce-scatters behind computation.

Problems start to arise when I try to use gradient accumulation in this setup. It is implemented like this:

    grads_sum = jax.tree_map(jnp.zeros_like, train_state.params)
    train_state, grads_sum = jax.lax.fori_loop(
        lower=0,
        upper=num_minibatches_in_batch,
        body_fun=_loop_body,
        init_val=(train_state, grads_sum),
        unroll=False,
    )

    mean_grads = jax.tree_map(lambda x: x / num_minibatches_in_batch, grads_sum)
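For completeness, here is a minimal self-contained sketch of what the loop body could look like; `_loop_body`, `loss_fn`, and `minibatches` are illustrative names, not taken from the original code:

```python
# Hedged sketch of a gradient-accumulation fori_loop body.
# Assumptions: params are a pytree of arrays, minibatches is a stacked
# array indexable by the loop counter, loss_fn(params, batch) -> scalar.
import jax
import jax.numpy as jnp

def make_loop_body(loss_fn, minibatches):
    def _loop_body(i, carry):
        params, grads_sum = carry
        # Per-minibatch gradients w.r.t. the parameters.
        grads = jax.grad(loss_fn)(params, minibatches[i])
        # Accumulate into the running sum, leaf by leaf.
        grads_sum = jax.tree_util.tree_map(jnp.add, grads_sum, grads)
        return params, grads_sum
    return _loop_body

# Toy usage: a vector of parameters and two minibatches.
params = jnp.ones((3,))
minibatches = jnp.arange(6.0).reshape(2, 3)
loss_fn = lambda p, x: jnp.sum((p * x) ** 2)

grads_sum = jnp.zeros_like(params)
body = make_loop_body(loss_fn, minibatches)
params, grads_sum = jax.lax.fori_loop(0, 2, body, (params, grads_sum))
mean_grads = grads_sum / 2
```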

When I set the gradient accumulation factor (num_minibatches_in_batch in this snippet) to a value greater than 1, I get the following error during compilation:

2024-07-01 12:57:35.488299: F external/xla/xla/service/collective_pipeliner.cc:675] Check failed: last_cloned != nullptr (0 vs. nullptr)

Here is --xla_dump_to result:
xla_dump.tgz

One important fact: if I set unroll in jax.lax.fori_loop to True, there is no compilation error and everything works. But this obviously leads to additional memory usage proportional to the gradient accumulation factor, so this hack doesn't seem viable.

My hypothesis is that when --xla_gpu_enable_while_loop_double_buffering=true is combined with pipelined collectives and the latency-hiding scheduler, the XLA compiler tries to double-buffer this fori_loop, which is undesired behavior here.

Basically, there are two problems:

  • A compiler bug whose error message is hard to trace back to its source in the JAX code
  • If my hypothesis is correct, I would like a mechanism to disable while-loop double buffering for specific loops (like the gradient accumulation loop) or to enable it only for specific loops (like the layers loop)

I've tested this on JAX 0.4.29 and 0.4.30.


qGentry commented Jul 1, 2024

Related JAX issue:
google/jax#22210


qGentry commented Jul 1, 2024

Actually, this problem persists even with --xla_gpu_enable_while_loop_double_buffering=false, so maybe it is not the source of the problem.

@qGentry qGentry changed the title Check failed in collective_pipeliner when using gradient accumulation with non-unrolled loop [XLA:GPU] Check failed in collective_pipeliner when using gradient accumulation with non-unrolled loop Jul 16, 2024
@rosiezou rosiezou self-assigned this Jul 22, 2024

Tixxx commented Jul 24, 2024

For the compilation error in collective_pipeliner, can you try with xla_gpu_run_post_layout_collective_pipeliner=false?

@rosiezou rosiezou assigned qGentry and unassigned rosiezou Jul 25, 2024
@rosiezou

Hi Filipp, could you try TJ's suggestion and update this issue with the results and any errors if applicable?


qGentry commented Jul 29, 2024

Hi guys, it looks like this flag was added in a very recent commit and hasn't made it into the latest JAX release (0.4.30). I'll wait for JAX 0.4.31 to test it. Thank you!
