
Strange Sharding with pipeline parallel around Scatter op - TP/PP/DP #15048

Open
ptoulme-aws opened this issue Jul 17, 2024 · 3 comments

@ptoulme-aws (Contributor)

I am seeing very strange sharding when combining pipeline parallelism with tensor and data parallelism.
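The mesh in play, as recorded in the op metadata of the dumps below, is Mesh('data': 2, 'expert': 1, 'fsdp': 1, 'seq': 1, 'pipeline': 4, 'model': 8) over 64 devices. A minimal JAX reconstruction of that setup, for anyone reproducing this (the axis names and sizes come from the metadata; the device reshape order is my assumption, not the actual axlearn code):

  import numpy as np
  import jax
  from jax.sharding import Mesh

  # 64 devices arranged per the ResourceEnv in the metadata; this needs a
  # 64-device platform (or forced XLA host devices) to actually run.
  devices = np.asarray(jax.devices()).reshape(2, 1, 1, 1, 4, 8)
  mesh = Mesh(devices, ("data", "expert", "fsdp", "seq", "pipeline", "model"))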

Below is the HLO exactly before partitioning:

  while.9466 = (s32[], bf16[4,128,512]{2,1,0}, bf16[4,128,512]{2,1,0}, bf16[4,512,128]{2,1,0}, bf16[4,128]{1,0}, /*index=5*/bf16[4,3,128,32,4]{4,3,2,1,0}, bf16[4,128,32,4]{3,2,1,0}, bf16[4,128]{1,0}, bf16[4,128,512]{2,1,0}, bf16[4,128,512]{2,1,0}, /*index=10*/bf16[4,512,128]{2,1,0}, bf16[4,128]{1,0}, bf16[4,3,128,32,4]{4,3,2,1,0}, bf16[4,128,32,4]{3,2,1,0}, bf16[4,128]{1,0}, /*index=15*/bf16[4,128,512]{2,1,0}, bf16[4,128,512]{2,1,0}, bf16[4,512,128]{2,1,0}, bf16[4,128]{1,0}, bf16[4,3,128,32,4]{4,3,2,1,0}, /*index=20*/bf16[4,128,32,4]{3,2,1,0}, bf16[4,128]{1,0}, bf16[4,128,512]{2,1,0}, bf16[4,128,512]{2,1,0}, bf16[4,512,128]{2,1,0}, /*index=25*/bf16[4,128]{1,0}, bf16[4,3,128,32,4]{4,3,2,1,0}, bf16[4,128,32,4]{3,2,1,0}, bf16[4,128]{1,0}, bf16[4,128,512]{2,1,0}, /*index=30*/bf16[4,128,512]{2,1,0}, bf16[4,512,128]{2,1,0}, bf16[4,128]{1,0}, bf16[4,3,128,32,4]{4,3,2,1,0}, bf16[4,128,32,4]{3,2,1,0}, /*index=35*/bf16[4,128]{1,0}, bf16[4,128,512]{2,1,0}, bf16[4,128,512]{2,1,0}, bf16[4,512,128]{2,1,0}, bf16[4,128]{1,0}, /*index=40*/bf16[4,3,128,32,4]{4,3,2,1,0}, bf16[4,128,32,4]{3,2,1,0}, bf16[4,128]{1,0}, bf16[4,128,512]{2,1,0}, bf16[4,128,512]{2,1,0}, /*index=45*/bf16[4,512,128]{2,1,0}, bf16[4,128]{1,0}, bf16[4,3,128,32,4]{4,3,2,1,0}, bf16[4,128,32,4]{3,2,1,0}, bf16[4,128]{1,0}, /*index=50*/bf16[4,128,512]{2,1,0}, bf16[4,128,512]{2,1,0}, bf16[4,512,128]{2,1,0}, bf16[4,128]{1,0}, bf16[4,3,128,32,4]{4,3,2,1,0}, /*index=55*/bf16[4,128,32,4]{3,2,1,0}, bf16[4,128]{1,0}, bf16[4,2,64,128]{3,2,1,0}, bf16[8,4,2,64,128]{4,3,2,1,0}, bf16[11,4,2,64,128]{4,3,2,1,0}, /*index=60*/bf16[11,4,2,64,128]{4,3,2,1,0}, bf16[11,4,2,1,64,64]{5,4,3,2,1,0}, bf16[11,8,4,2,64,128]{5,4,3,2,1,0}, bf16[11,8,4,2,1,64,64]{6,5,4,3,2,1,0}, s32[11]{0}, /*index=65*/bf16[4,128,512]{2,1,0}, bf16[4,128,512]{2,1,0}, bf16[4,512,128]{2,1,0}, bf16[4,128]{1,0}, bf16[4,3,128,32,4]{4,3,2,1,0}, /*index=70*/bf16[4,128,32,4]{3,2,1,0}, bf16[4,128]{1,0}, bf16[4,128,512]{2,1,0}, bf16[4,128,512]{2,1,0}, bf16[4,512,128]{2,1,0}, /*index=75*/bf16[4,128]{1,0}, bf16[4,3,128,32,4]{4,3,2,1,0}, bf16[4,128,32,4]{3,2,1,0}, bf16[4,128]{1,0}, bf16[4,128,512]{2,1,0}, /*index=80*/bf16[4,128,512]{2,1,0}, bf16[4,512,128]{2,1,0}, bf16[4,128]{1,0}, bf16[4,3,128,32,4]{4,3,2,1,0}, bf16[4,128,32,4]{3,2,1,0}, /*index=85*/bf16[4,128]{1,0}, bf16[4,128,512]{2,1,0}, bf16[4,128,512]{2,1,0}, bf16[4,512,128]{2,1,0}, bf16[4,128]{1,0}, /*index=90*/bf16[4,3,128,32,4]{4,3,2,1,0}, bf16[4,128,32,4]{3,2,1,0}, bf16[4,128]{1,0}, bf16[4,128,512]{2,1,0}, bf16[4,128,512]{2,1,0}, /*index=95*/bf16[4,512,128]{2,1,0}, bf16[4,128]{1,0}, bf16[4,3,128,32,4]{4,3,2,1,0}, bf16[4,128,32,4]{3,2,1,0}, bf16[4,128]{1,0}, /*index=100*/bf16[4,128,512]{2,1,0}, bf16[4,128,512]{2,1,0}, bf16[4,512,128]{2,1,0}, bf16[4,128]{1,0}, bf16[4,3,128,32,4]{4,3,2,1,0}, /*index=105*/bf16[4,128,32,4]{3,2,1,0}, bf16[4,128]{1,0}, bf16[4,128,512]{2,1,0}, bf16[4,128,512]{2,1,0}, bf16[4,512,128]{2,1,0}, /*index=110*/bf16[4,128]{1,0}, bf16[4,3,128,32,4]{4,3,2,1,0}, bf16[4,128,32,4]{3,2,1,0}, bf16[4,128]{1,0}, bf16[4,128,512]{2,1,0}, /*index=115*/bf16[4,128,512]{2,1,0}, bf16[4,512,128]{2,1,0}, bf16[4,128]{1,0}, bf16[4,3,128,32,4]{4,3,2,1,0}, bf16[4,128,32,4]{3,2,1,0}, /*index=120*/bf16[4,128]{1,0}) while(tuple.4284), condition=region_163.9341, body=region_50.5942, sharding={{replicated}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, /*index=5*/{devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, 
{devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, /*index=10*/{devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, /*index=15*/{devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, /*index=20*/{devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, /*index=25*/{devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, /*index=30*/{devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, /*index=35*/{devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, /*index=40*/{devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, /*index=45*/{devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, /*index=50*/{devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, /*index=55*/{devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[1,4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[1,4,2,8,1]<=[2,4,8]T(1,0,2)}, /*index=60*/{devices=[1,4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[1,4,2,1,1,1,8]0,1,2,3,4,5,6,7,32,33,34,35,36,37,38,39,8,9,10,11,12,13,14,15,40,41,42,43,44,45,46,47,16,17,18,19,20,21,22,23,48,49,50,51,52,53,54,55,24,25,26,27,28,29,30,31,56,57,58,59,60,61,62,63 last_tile_dim_replicate}, {devices=[1,1,4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[1,1,4,2,1,1,1,8]0,1,2,3,4,5,6,7,32,33,34,35,36,37,38,39,8,9,10,11,12,13,14,15,40,41,42,43,44,45,46,47,16,17,18,19,20,21,22,23,48,49,50,51,52,53,54,55,24,25,26,27,28,29,30,31,56,57,58,59,60,61,62,63 last_tile_dim_replicate}, {replicated}, /*index=65*/{devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, /*index=70*/{devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, /*index=75*/{devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, 
{devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, /*index=80*/{devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, /*index=85*/{devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, /*index=90*/{devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, /*index=95*/{devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, /*index=100*/{devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, /*index=105*/{devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, /*index=110*/{devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, /*index=115*/{devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, /*index=120*/{devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/layer[StackedTransformerLayer]/while[cond_nconsts=0 body_nconsts=62]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=392}
  get-tuple-element.9525 = bf16[8,4,2,64,128]{4,3,2,1,0} get-tuple-element(while.9466), index=58, sharding={devices=[1,4,2,8,1]<=[2,4,8]T(1,0,2)}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/layer[StackedTransformerLayer]/while[cond_nconsts=0 body_nconsts=62]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=392}
  copy.132 = bf16[8,4,2,64,128]{4,3,2,1,0} copy(get-tuple-element.9525), sharding={devices=[1,4,2,8,1]<=[2,4,8]T(1,0,2)}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/sharding_constraint[sharding=GSPMDSharding({replicated}) resource_env=ResourceEnv(mesh=Mesh(\'data\': 2, \'expert\': 1, \'fsdp\': 1, \'seq\': 1, \'pipeline\': 4, \'model\': 8), ()) unconstrained_dims={0, 1, 2, 3}]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/utils.py" source_line=414}
  slice.56 = bf16[8,1,2,64,128]{4,3,2,1,0} slice(copy.132), slice={[0:8], [0:1], [0:2], [0:64], [0:128]}, sharding={devices=[1,4,2,8,1]<=[2,4,8]T(1,0,2)}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(root[Learner])/jvp(model[Model])/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/jit(_pad)/pad[padding_config=((0, 0, 0), (0, -3, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=246}
  reshape.1062 = bf16[8,2,64,128]{3,2,1,0} reshape(slice.56), sharding={devices=[1,2,8,1,4]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/reduce_sum[axes=(1,)]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=247}
  transpose.9599 = bf16[2,8,64,128]{3,2,0,1} transpose(reshape.1062), dimensions={1,0,2,3}, sharding={devices=[2,1,8,1,4]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/transpose[permutation=(1, 0, 2, 3)]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=523}
  reshape.9600 = bf16[16,64,128]{2,1,0} reshape(transpose.9599), sharding={devices=[2,1,1,32]<=[64] last_tile_dim_replicate}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/reshape[new_sizes=(16, 64, 128) dimensions=None]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=522}
  copy.131 = bf16[16,64,128]{2,1,0} copy(reshape.9600), sharding={devices=[2,1,1,32]<=[64] last_tile_dim_replicate}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/emb[TransformerTextEmbeddings]/token_emb[Embedding]/sharding_constraint[sharding=GSPMDSharding({devices=[2,1,1,32]<=[64] last_tile_dim_replicate}) resource_env=ResourceEnv(mesh=Mesh(\'data\': 2, \'expert\': 1, \'fsdp\': 1, \'seq\': 1, \'pipeline\': 4, \'model\': 8), ()) unconstrained_dims=set()]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/utils.py" source_line=414}
  scatter.9606 = bf16[1600,128]{1,0} scatter(broadcast.2639, reshape.774, copy.131), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=region_165.9602, sharding={devices=[8,1,8]<=[8,8]T(1,0) last_tile_dim_replicate}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/emb[TransformerTextEmbeddings]/token_emb[Embedding]/scatter-add[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(2,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False unique_indices=False mode=GatherScatterMode.PROMISE_IN_BOUNDS]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/layers.py" source_line=1245}
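For context, the scatter.9606 op above is the backward of the embedding lookup in token_emb[Embedding] (layers.py:1245): output gradients are scatter-added into the [1600, 128] table. A minimal JAX sketch of the same pattern, with illustrative shapes and names rather than the actual axlearn code:

  import jax.numpy as jnp

  def embedding_grad(grad_out, token_ids, vocab_size=1600, dim=128):
      # grad_out: [batch, seq, dim] gradients of the looked-up embeddings.
      # token_ids: [batch, seq] int32 indices into the embedding table.
      table_grad = jnp.zeros((vocab_size, dim), dtype=grad_out.dtype)
      # .at[...].add lowers to an HLO scatter with an add combiner, the
      # same scatter-add form as scatter.9606 above.
      return table_grad.at[token_ids].add(grad_out)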

After partitioning I see a strange all-reduce over half of the TP group, followed by collective-permutes, and then a strange all-gather over a group of eight seemingly random workers. I have tried debugging this in the SPMD logs with VLOG=10, but I have not found the reason behind this sharding.

I would have expected an all-gather over the TP group before the scatter op. Does anyone have an explanation, or pointers on how to debug this?
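One thing that helps besides VLOG, sketched here with standard XLA dump flags (the dump path is arbitrary), is dumping the HLO after each pass so the sharding-propagation and SPMD-partitioner outputs can be diffed directly:

  import os

  # Must be set before JAX initializes its backend. Dumps per-pass HLO,
  # filtered to sharding propagation and SPMD partitioning.
  os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + (
      " --xla_dump_to=/tmp/hlo_dump"
      " --xla_dump_hlo_pass_re=sharding|spmd"
  )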

After partitioning:

  while.1 = (s32[], bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,128]{1,0}, /*index=5*/bf16[1,3,64,4,4]{4,3,2,1,0}, bf16[1,64,4,4]{3,2,1,0}, bf16[1,128]{1,0}, bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, /*index=10*/bf16[1,64,64]{2,1,0}, bf16[1,128]{1,0}, bf16[1,3,64,4,4]{4,3,2,1,0}, bf16[1,64,4,4]{3,2,1,0}, bf16[1,128]{1,0}, /*index=15*/bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,128]{1,0}, bf16[1,3,64,4,4]{4,3,2,1,0}, /*index=20*/bf16[1,64,4,4]{3,2,1,0}, bf16[1,128]{1,0}, bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, /*index=25*/bf16[1,128]{1,0}, bf16[1,3,64,4,4]{4,3,2,1,0}, bf16[1,64,4,4]{3,2,1,0}, bf16[1,128]{1,0}, bf16[1,64,64]{2,1,0}, /*index=30*/bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,128]{1,0}, bf16[1,3,64,4,4]{4,3,2,1,0}, bf16[1,64,4,4]{3,2,1,0}, /*index=35*/bf16[1,128]{1,0}, bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,128]{1,0}, /*index=40*/bf16[1,3,64,4,4]{4,3,2,1,0}, bf16[1,64,4,4]{3,2,1,0}, bf16[1,128]{1,0}, bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, /*index=45*/bf16[1,64,64]{2,1,0}, bf16[1,128]{1,0}, bf16[1,3,64,4,4]{4,3,2,1,0}, bf16[1,64,4,4]{3,2,1,0}, bf16[1,128]{1,0}, /*index=50*/bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,128]{1,0}, bf16[1,3,64,4,4]{4,3,2,1,0}, /*index=55*/bf16[1,64,4,4]{3,2,1,0}, bf16[1,128]{1,0}, bf16[1,1,8,128]{3,2,1,0}, bf16[8,1,1,8,128]{4,3,2,1,0}, bf16[11,1,1,8,128]{4,3,2,1,0}, /*index=60*/bf16[11,1,1,8,128]{4,3,2,1,0}, bf16[11,1,1,1,64,64]{5,4,3,2,1,0}, bf16[11,8,1,1,8,128]{5,4,3,2,1,0}, bf16[11,8,1,1,1,64,64]{6,5,4,3,2,1,0}, s32[11]{0}, /*index=65*/bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,128]{1,0}, bf16[1,3,64,4,4]{4,3,2,1,0}, /*index=70*/bf16[1,64,4,4]{3,2,1,0}, bf16[1,128]{1,0}, bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, /*index=75*/bf16[1,128]{1,0}, bf16[1,3,64,4,4]{4,3,2,1,0}, bf16[1,64,4,4]{3,2,1,0}, bf16[1,128]{1,0}, bf16[1,64,64]{2,1,0}, /*index=80*/bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,128]{1,0}, bf16[1,3,64,4,4]{4,3,2,1,0}, bf16[1,64,4,4]{3,2,1,0}, /*index=85*/bf16[1,128]{1,0}, bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,128]{1,0}, /*index=90*/bf16[1,3,64,4,4]{4,3,2,1,0}, bf16[1,64,4,4]{3,2,1,0}, bf16[1,128]{1,0}, bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, /*index=95*/bf16[1,64,64]{2,1,0}, bf16[1,128]{1,0}, bf16[1,3,64,4,4]{4,3,2,1,0}, bf16[1,64,4,4]{3,2,1,0}, bf16[1,128]{1,0}, /*index=100*/bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,128]{1,0}, bf16[1,3,64,4,4]{4,3,2,1,0}, /*index=105*/bf16[1,64,4,4]{3,2,1,0}, bf16[1,128]{1,0}, bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, /*index=110*/bf16[1,128]{1,0}, bf16[1,3,64,4,4]{4,3,2,1,0}, bf16[1,64,4,4]{3,2,1,0}, bf16[1,128]{1,0}, bf16[1,64,64]{2,1,0}, /*index=115*/bf16[1,64,64]{2,1,0}, bf16[1,64,64]{2,1,0}, bf16[1,128]{1,0}, bf16[1,3,64,4,4]{4,3,2,1,0}, bf16[1,64,4,4]{3,2,1,0}, /*index=120*/bf16[1,128]{1,0}) while(tuple.3), condition=region_163.9341_spmd, body=region_50.5942_spmd, sharding={{replicated}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, /*index=5*/{devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, 
/*index=10*/{devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, /*index=15*/{devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, /*index=20*/{devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, /*index=25*/{devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, /*index=30*/{devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, /*index=35*/{devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, /*index=40*/{devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, /*index=45*/{devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, /*index=50*/{devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, /*index=55*/{devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[1,4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[1,4,2,8,1]<=[2,4,8]T(1,0,2)}, /*index=60*/{devices=[1,4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[1,4,2,1,1,1,8]0,1,2,3,4,5,6,7,32,33,34,35,36,37,38,39,8,9,10,11,12,13,14,15,40,41,42,43,44,45,46,47,16,17,18,19,20,21,22,23,48,49,50,51,52,53,54,55,24,25,26,27,28,29,30,31,56,57,58,59,60,61,62,63 last_tile_dim_replicate}, {devices=[1,1,4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[1,1,4,2,1,1,1,8]0,1,2,3,4,5,6,7,32,33,34,35,36,37,38,39,8,9,10,11,12,13,14,15,40,41,42,43,44,45,46,47,16,17,18,19,20,21,22,23,48,49,50,51,52,53,54,55,24,25,26,27,28,29,30,31,56,57,58,59,60,61,62,63 last_tile_dim_replicate}, {replicated}, /*index=65*/{devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, /*index=70*/{devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, /*index=75*/{devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, 
{devices=[4,2,8]<=[2,4,8]T(1,0,2)}, /*index=80*/{devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, /*index=85*/{devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, /*index=90*/{devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, /*index=95*/{devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, /*index=100*/{devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, /*index=105*/{devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, /*index=110*/{devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,2,8]<=[2,4,8]T(1,0,2)}, /*index=115*/{devices=[4,2,8]<=[2,4,8]T(1,0,2)}, {devices=[4,8,2]<=[2,32]T(1,0)}, {devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, {devices=[4,1,2,8,1]<=[2,4,8]T(1,0,2)}, {devices=[4,2,8,1]<=[2,4,8]T(1,0,2)}, /*index=120*/{devices=[4,1,16]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/layer[StackedTransformerLayer]/while[cond_nconsts=0 body_nconsts=62]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=392}
  get-tuple-element.200 = bf16[8,1,1,8,128]{4,3,2,1,0} get-tuple-element(while.1), index=58, sharding={devices=[1,4,2,8,1]<=[2,4,8]T(1,0,2)}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/layer[StackedTransformerLayer]/while[cond_nconsts=0 body_nconsts=62]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=392}
  copy.426 = bf16[8,1,1,8,128]{4,3,2,1,0} copy(get-tuple-element.200), sharding={devices=[1,4,2,8,1]<=[2,4,8]T(1,0,2)}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/sharding_constraint[sharding=GSPMDSharding({replicated}) resource_env=ResourceEnv(mesh=Mesh(\'data\': 2, \'expert\': 1, \'fsdp\': 1, \'seq\': 1, \'pipeline\': 4, \'model\': 8), ()) unconstrained_dims={0, 1, 2, 3}]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/utils.py" source_line=414}
  copy.427 = bf16[8,1,1,8,128]{4,3,2,1,0} copy(copy.426), sharding={devices=[1,4,2,8,1]<=[2,4,8]T(1,0,2)}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(root[Learner])/jvp(model[Model])/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/jit(_pad)/pad[padding_config=((0, 0, 0), (0, -3, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=246}
  collective-permute.10 = bf16[8,1,1,8,128]{4,3,2,1,0} collective-permute(copy.427), channel_id=288, source_target_pairs={{0,0},{1,4},{2,32},{3,36},{4,8},{5,12},{6,40},{7,44},{32,16},{33,20},{34,48},{35,52},{36,24},{37,28},{38,56},{39,60},{8,1},{9,5},{10,33},{11,37},{12,9},{13,13},{14,41},{15,45},{40,17},{41,21},{42,49},{43,53},{44,25},{45,29},{46,57},{47,61},{16,2},{17,6},{18,34},{19,38},{20,10},{21,14},{22,42},{23,46},{48,18},{49,22},{50,50},{51,54},{52,26},{53,30},{54,58},{55,62},{24,3},{25,7},{26,35},{27,39},{28,11},{29,15},{30,43},{31,47},{56,19},{57,23},{58,51},{59,55},{60,27},{61,31},{62,59},{63,63}}, sharding={devices=[1,4,2,8,1]0,4,32,36,8,12,40,44,16,20,48,52,24,28,56,60,1,5,33,37,9,13,41,45,17,21,49,53,25,29,57,61,2,6,34,38,10,14,42,46,18,22,50,54,26,30,58,62,3,7,35,39,11,15,43,47,19,23,51,55,27,31,59,63}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/reduce_sum[axes=(1,)]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=247}
  constant.6732 = bf16[] constant(0), metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/reduce_sum[axes=(1,)]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=247}
  broadcast.818 = bf16[8,1,1,8,128]{4,3,2,1,0} broadcast(constant.6732), dimensions={}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/reduce_sum[axes=(1,)]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=247}
  select.93 = bf16[8,1,1,8,128]{4,3,2,1,0} select(broadcast.817, collective-permute.10, broadcast.818), metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/reduce_sum[axes=(1,)]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=247}
  all-reduce.134 = bf16[8,1,1,8,128]{4,3,2,1,0} all-reduce(select.93), channel_id=289, replica_groups={{0,1,2,3},{4,5,6,7},{32,33,34,35},{36,37,38,39},{8,9,10,11},{12,13,14,15},{40,41,42,43},{44,45,46,47},{16,17,18,19},{20,21,22,23},{48,49,50,51},{52,53,54,55},{24,25,26,27},{28,29,30,31},{56,57,58,59},{60,61,62,63}}, use_global_device_ids=true, to_apply=add.99, sharding={devices=[1,1,2,8,1,4]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/reduce_sum[axes=(1,)]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=247}
  reshape.2138 = bf16[8,1,8,128]{3,2,1,0} reshape(all-reduce.134), sharding={devices=[1,2,8,1,4]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/reduce_sum[axes=(1,)]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=247}
  transpose.141 = bf16[1,8,8,128]{3,2,0,1} transpose(reshape.2138), dimensions={1,0,2,3}, sharding={devices=[2,1,8,1,4]<=[2,4,8]T(1,0,2) last_tile_dim_replicate}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/transpose[permutation=(1, 0, 2, 3)]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=523}
  collective-permute.11 = bf16[1,8,8,128]{3,2,0,1} collective-permute(transpose.141), channel_id=290, source_target_pairs={{0,0},{1,1},{2,2},{3,3},{4,4},{5,5},{6,6},{7,7},{32,8},{33,9},{34,10},{35,11},{36,12},{37,13},{38,14},{39,15},{8,16},{9,17},{10,18},{11,19},{12,20},{13,21},{14,22},{15,23},{40,24},{41,25},{42,26},{43,27},{44,28},{45,29},{46,30},{47,31},{16,32},{17,33},{18,34},{19,35},{20,36},{21,37},{22,38},{23,39},{48,40},{49,41},{50,42},{51,43},{52,44},{53,45},{54,46},{55,47},{24,48},{25,49},{26,50},{27,51},{28,52},{29,53},{30,54},{31,55},{56,56},{57,57},{58,58},{59,59},{60,60},{61,61},{62,62},{63,63}}, sharding={devices=[2,1,8,1,4]<=[64] last_tile_dim_replicate}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/reshape[new_sizes=(16, 64, 128) dimensions=None]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=522}
  all-gather.142 = bf16[1,8,64,128]{3,2,0,1} all-gather(collective-permute.11), channel_id=291, replica_groups={{0,4,8,12,16,20,24,28},{1,5,9,13,17,21,25,29},{2,6,10,14,18,22,26,30},{3,7,11,15,19,23,27,31},{32,36,40,44,48,52,56,60},{33,37,41,45,49,53,57,61},{34,38,42,46,50,54,58,62},{35,39,43,47,51,55,59,63}}, dimensions={2}, use_global_device_ids=true, sharding={devices=[2,1,1,1,32]<=[64] last_tile_dim_replicate}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/reshape[new_sizes=(16, 64, 128) dimensions=None]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=522}
  reshape.2142 = bf16[8,64,128]{2,1,0} reshape(all-gather.142), sharding={devices=[2,1,1,32]<=[64] last_tile_dim_replicate}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/reshape[new_sizes=(16, 64, 128) dimensions=None]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=522}
  copy.428 = bf16[8,64,128]{2,1,0} copy(reshape.2142), sharding={devices=[2,1,1,32]<=[64] last_tile_dim_replicate}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/emb[TransformerTextEmbeddings]/token_emb[Embedding]/sharding_constraint[sharding=GSPMDSharding({devices=[2,1,1,32]<=[64] last_tile_dim_replicate}) resource_env=ResourceEnv(mesh=Mesh(\'data\': 2, \'expert\': 1, \'fsdp\': 1, \'seq\': 1, \'pipeline\': 4, \'model\': 8), ()) unconstrained_dims=set()]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/utils.py" source_line=414}
  scatter = bf16[200,128]{1,0} scatter(select.94, subtract.270, copy.428), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=region_165.9602, sharding={replicated}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/emb[TransformerTextEmbeddings]/token_emb[Embedding]/scatter-add[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(2,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False unique_indices=False mode=GatherScatterMode.PROMISE_IN_BOUNDS]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/layers.py" source_line=1245}
@ptoulme-aws (Contributor, Author)

The sharding metadata also looks correct for my ops. I do not understand why the all-gather is not over the TP group, i.e. {0-7} in the first case.
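For reference, the iota tile-assignment notation can be expanded to see which physical devices each group contains: <=[2,4,8]T(1,0,2) means an iota over a [2,4,8] array, transposed by (1,0,2), then flattened. A quick numpy check (my reading of the notation; taking the last axis as the model dim is an inference from the mesh sizes data=2, pipeline=4, model=8):

  import numpy as np

  # "<=[2,4,8]T(1,0,2)": iota reshaped to (2, 4, 8), transposed by
  # (1, 0, 2), then flattened, gives the device order.
  order = np.arange(64).reshape(2, 4, 8).transpose(1, 0, 2)
  print(order.reshape(-1)[:16])
  # [ 0  1  2  3  4  5  6  7 32 33 34 35 36 37 38 39] -- matches the
  # explicit device list spelled out in the while-op sharding above.

  # Taking the last axis as the model (TP) axis, the TP groups are rows:
  print(order.reshape(-1, 8)[:4])
  # [[ 0  1  2  3  4  5  6  7]
  #  [32 33 34 35 36 37 38 39]
  #  [ 8  9 10 11 12 13 14 15]
  #  [40 41 42 43 44 45 46 47]]

So under this device order the first model group is indeed {0-7}, which is what I expected the all-gather to use.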

@ptoulme-aws (Contributor, Author)

Using the CPU HLO runner, I compared the HLO above (with those collective ops) against a version produced by a pass I wrote that simplifies the chain to an all-gather over the TP group; a sketch of the comparison follows the HLO below. According to the HLO runner, the two produce equal results:

  get-tuple-element.3416 = bf16[8,1,1,8,128]{4,3,2,1,0} get-tuple-element(tuple.107), index=58, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/layer[StackedTransformerLayer]/while[cond_nconsts=0 body_nconsts=62]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=392}
  reshape.7052 = bf16[1,8,8,128]{3,2,1,0} reshape(get-tuple-element.3416)
  all-gather.148 = bf16[1,8,64,128]{3,2,1,0} all-gather(reshape.7052), channel_id=291, replica_groups={{0,1,2,3,4,5,6,7},{8,9,10,11,12,13,14,15},{16,17,18,19,20,21,22,23},{24,25,26,27,28,29,30,31},{32,33,34,35,36,37,38,39},{40,41,42,43,44,45,46,47},{48,49,50,51,52,53,54,55},{56,57,58,59,60,61,62,63}}, dimensions={2}, use_global_device_ids=true
  reshape.6041 = bf16[512,1,128]{2,1,0} reshape(all-gather.148), metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/transformer[PipelinedTransformerLayer]/pipeline[_TransformerPipeline]/reshape[new_sizes=(16, 64, 128) dimensions=None]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/pipeline.py" source_line=522}
  scatter = bf16[200,128]{1,0} scatter(broadcast.4413, reshape.6039, reshape.6041), update_window_dims={1,2}, inserted_window_dims={}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=region_165.9602, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/transpose(jvp(model[Model]))/decoder[Decoder]/emb[TransformerTextEmbeddings]/token_emb[Embedding]/scatter-add[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(2,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False unique_indices=False mode=GatherScatterMode.PROMISE_IN_BOUNDS]" source_file="/shared_new/ptoulme/axlearn/axlearn/axlearn/common/layers.py" source_line=1245}
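The comparison workflow was roughly this (a sketch: run_hlo_module is XLA's HLO runner tool, the file names are placeholders for the two dumped modules, and exact flag spellings may differ between builds):

  import subprocess

  # Run each module on CPU; run_hlo_module also checks the result against
  # its reference backend.
  for hlo in ("before_pass.hlo", "after_pass.hlo"):
      subprocess.run(["run_hlo_module", "--platform=cpu", hlo], check=True)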

@golechwierowicz (Member)

In general, the SPMD partitioner is hard to debug; personally, I have not found a better method than methodically inserting print statements into the pass.

It does look like the reshape -> transpose -> reshape chain could be simplified to a single reshape. If it were simplified, would the SPMD partitioner output the simpler collective chain?
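One way to check that empirically would be to lower both variants and diff the optimized HLO (a sketch; the input shape is taken from the dump above, and the real mesh and in/out shardings would be needed for the collective chains to actually appear):

  import jax
  import jax.numpy as jnp

  def chain(x):   # the reshape -> transpose -> reshape seen in the dump
      y = x.reshape(8, 2, 64, 128).transpose(1, 0, 2, 3)
      return y.reshape(16, 64, 128)

  def direct(x):  # the candidate single-reshape replacement
      return x.reshape(16, 64, 128)

  x = jnp.zeros((8, 1, 2, 64, 128), jnp.bfloat16)
  for f in (chain, direct):
      # Post-optimization HLO; under the real shardings, this is where the
      # two collective chains would differ.
      print(jax.jit(f).lower(x).compile().as_text())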
