forked from epfLLM/Megatron-LLM
-
Notifications
You must be signed in to change notification settings - Fork 2
/
transformer.py
1347 lines (1177 loc) · 61.7 KB
/
transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Transformer."""
import math
from contextlib import nullcontext
from typing import Callable
import torch
import flash_attn
from torch.nn import functional as F
from einops import rearrange
from megatron import core, get_num_microbatches
from .module import MegatronModule
import megatron.core
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType, PositionEmbeddingType
from megatron.model import LayerNorm
from megatron.model import RMSNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, erf_gelu
# Extracted from: https://github.com/bigscience-workshop/Megatron-DeepSpeed
from .glu_activations import GLU_ACTIVATIONS
from megatron.model.positional_embeddings import precompute_freqs_cis, apply_rotary_emb
from flash_attn.bert_padding import pad_input, unpad_input_for_concatenated_sequences
""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
p: number of model parallel partitions
np: n/p
hp: h/p
hn: h/n
b: batch size
s: sequence length
l: number of layers
Transformer takes input of size [s, b, h] and returns a
tensor of the same size. We use the following arguments:
hyperparameters: transformer hyperparameters
"""
class DropPath(MegatronModule):
    """Stochastic depth: randomly drop an entire residual branch per sample.

    Meant to wrap the main path of a residual block. While training, every
    batch entry is either zeroed (probability ``drop_prob``) or rescaled by
    ``1 / (1 - drop_prob)`` so the expectation is unchanged. In eval mode,
    or when ``drop_prob`` is 0, the input is returned as-is.
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        # Disabled (or inference): identity.
        if not self.training or self.drop_prob == 0.:
            return hidden_state

        keep_prob = 1 - self.drop_prob
        # hidden_state is [s, b, h]: draw one value per batch entry (dim 1)
        # and broadcast it across every other dimension.
        mask_shape = [1] * hidden_state.ndim
        mask_shape[1] = hidden_state.shape[1]
        noise = torch.rand(tuple(mask_shape),
                           dtype=hidden_state.dtype,
                           device=hidden_state.device)
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
        binary_mask = torch.floor(keep_prob + noise)
        return hidden_state.div(keep_prob) * binary_mask
def _args_to_kwargs(args):
    """Build the keyword arguments shared by all tensor-parallel linear layers.

    Translates attributes of the Megatron ``args`` namespace into the keyword
    names accepted by ``ColumnParallelLinear`` / ``RowParallelLinear``.
    """
    return dict(
        params_dtype=args.params_dtype,
        use_cpu_initialization=args.use_cpu_initialization,
        perform_initialization=args.perform_initialization,
        gradient_accumulation_fusion=args.gradient_accumulation_fusion,
        sequence_parallel_enabled=args.sequence_parallel,
    )
class ParallelMLP(MegatronModule):
    """Tensor-parallel MLP.

    Projects the [s, b, h] input up to the FFN width (``args.ffn_hidden_size``,
    doubled when a GLU activation is configured because GLU halves the last
    dimension), applies the non-linearity, and projects back to h.

    Args:
        init_method: initializer for the up-projection weights.
        output_layer_init_method: initializer for the down-projection weights.
        args: Megatron arguments namespace.
        world_size: tensor-model-parallel world size.
    """

    def __init__(self,
                 init_method,
                 output_layer_init_method,
                 args,
                 world_size):
        super(ParallelMLP, self).__init__()
        # Project to 4h.
        self.dense_h_to_4h = megatron.core.tensor_parallel.ColumnParallelLinear(
            args.hidden_size,
            # GLU is a special activation that divides the dimension by a factor 2.
            2 * args.ffn_hidden_size if args.glu_activation else args.ffn_hidden_size,
            bias=args.use_bias,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
            async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
            **_args_to_kwargs(args),
            world_size=world_size)
        self.use_bias = args.use_bias
        # The fused kernel computes gelu(x + bias): it needs a real bias tensor
        # and cannot replace a GLU activation (a GLU halves the width that
        # dense_4h_to_h expects). Only enable it when both hold; otherwise use
        # the unfused activation path instead of crashing on a None bias or a
        # shape mismatch.
        self.bias_gelu_fusion = bool(
            args.bias_gelu_fusion and args.use_bias and not args.glu_activation)
        if args.glu_activation:
            self.activation_func = GLU_ACTIVATIONS[args.glu_activation]
        elif args.onnx_safe:
            self.activation_func = erf_gelu
        else:
            self.activation_func = F.gelu
        # Project back to h.
        self.dense_4h_to_h = megatron.core.tensor_parallel.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            bias=args.use_bias,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            **_args_to_kwargs(args),
            world_size=world_size)

    def forward(self, hidden_states):
        """Apply the MLP to ``hidden_states`` ([s, b, h]).

        Returns:
            ``(output, output_bias)``: the [s, b, h] output and the
            down-projection bias, returned separately (``skip_bias_add=True``)
            so the caller can fuse it with dropout + residual add.
            NOTE(review): presumably ``output_bias`` is None when
            ``use_bias`` is False — confirm in RowParallelLinear.
        """
        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
        if self.bias_gelu_fusion:
            intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
        elif self.use_bias:
            intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)
        else:
            intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias
class CoreAttention(MegatronModule):
    """Unfused (non-flash) scaled dot-product attention for one TP partition.

    Computes softmax(scale * Q K^T, masked) V over the attention heads held
    by this tensor-parallel partition, with dropout applied to the attention
    probabilities.

    Args:
        layer_number: layer index (values < 1 are clamped to 1); used as an
            extra score-scaling coefficient when
            ``apply_query_key_layer_scaling`` is enabled.
        attn_mask_type: mask type forwarded to ``FusedScaleMaskSoftmax``.
        args: Megatron arguments namespace.
        world_size: tensor-model-parallel world size.
    """
    def __init__(self,
                 layer_number,
                 attn_mask_type=AttnMaskType.padding,
                 args=None,
                 world_size=None):
        super(CoreAttention, self).__init__()
        self.fp16 = args.fp16
        self.bf16 = args.bf16
        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        # Layer scaling shrinks the raw scores by an extra per-layer factor
        # (see `coeff` below), so the softmax is forced to run in fp32.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = args.sequence_parallel
        projection_size = args.kv_channels * args.num_attention_heads
        # Per attention head and per partition values.
        self.hidden_size_per_partition = core.utils.divide(projection_size,
                                                           world_size)
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            args.num_attention_heads, world_size)
        coeff = None
        # Raw scores are multiplied by 1 / norm_factor, where
        # norm_factor = sqrt(hn), times layer_number when layer scaling is on.
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
        # NOTE(review): `coeff` is handed to the softmax, presumably so it can
        # undo the extra layer_number scaling in fp32 — confirm in
        # FusedScaleMaskSoftmax.
        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)
        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    def forward(self, query_layer, key_layer,
                value_layer, attention_mask):
        """Compute attention for this partition's heads.

        Args:
            query_layer: [sq, b, np, hn] queries.
            key_layer: [sk, b, np, hn] keys.
            value_layer: [sk, b, np, hn] values.
            attention_mask: mask passed straight to ``scale_mask_softmax``.

        Returns:
            The context layer, shape [sq, b, hp].
        """
        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================
        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))
        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)
        # preallocting input tensor: [b * np, sq, sk]
        # (beta=0.0 below means the buffer's previous contents are ignored.)
        matmul_input_buffer = megatron.core.mpu.get_global_memory_buffer().get_tensor(
            (output_size[0]*output_size[1], output_size[2], output_size[3]),
            query_layer.dtype, "mpu")
        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))
        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)
        # ===========================
        # Attention probs and dropout
        # ===========================
        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        # Without sequence parallelism the dropout RNG is drawn inside the
        # tracker's forked state; with it, the default RNG state is used.
        if not self.sequence_parallel:
            with megatron.core.tensor_parallel.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)
        else:
            attention_probs = self.attention_dropout(attention_probs)
        # =========================
        # Context layer. [sq, b, hp]
        # =========================
        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        # context layer shape: [b, np, sq, hn]
        # (query_layer is [sq, b * np, hn] here, so size(0) is sq.)
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))
        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)
        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)
        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)
        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer
class ParallelAttention(MegatronModule):
    """Parallel self- or cross-attention layer.

    Takes input of size [s, b, h] and returns output of the same size, plus
    the output-projection bias, which is returned separately because the
    output layer uses ``skip_bias_add=True``.

    Self-attention supports grouped-query attention (``num_attention_heads_kv``
    key/value heads shared by groups of query heads), rotary position
    embeddings, FlashAttention (optionally with a sliding window and with
    packed variable-length inputs), and a key/value cache for inference.

    Args:
        init_method: initializer for the input projection(s).
        output_layer_init_method: initializer for the output projection.
        layer_number: layer index (values < 1 are clamped to 1); also the
            key of this layer's entry in the inference key/value cache.
        attention_type: ``AttnType.self_attn`` or ``AttnType.cross_attn``.
        attn_mask_type: attention mask type.
        world_size: tensor-model-parallel world size (required).
        args: Megatron arguments namespace.
    """

    def __init__(self,
                 init_method,
                 output_layer_init_method,
                 layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding,
                 world_size: int = None,
                 args=None):
        super(ParallelAttention, self).__init__()
        assert world_size is not None
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype
        self.sequence_parallel = args.sequence_parallel
        self.use_flash_attn = args.use_flash_attn
        self.sliding_window_size = args.sliding_window_size
        self.num_attention_heads_kv = args.num_attention_heads_kv
        self.num_attention_heads = args.num_attention_heads
        self.seq_length = args.seq_length
        self.packed_input = args.packed_input
        if self.use_flash_attn:
            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
                                                          'self-attention for now')
            assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
                                                                'supports causal mask for now')
        # If sliding window is enabled, make sure the installed flash-attn
        # exposes the `window_size` argument.
        if self.sliding_window_size is not None:
            import inspect
            # https://github.com/huggingface/transformers/blob/7e1eff7600085814eac65876d4d8a0e38c2f6ccc/src/transformers/models/mistral/modeling_mistral.py#L50C5-L50C32
            assert "window_size" in list(inspect.signature(
                flash_attn.flash_attn_func
            ).parameters), "The current flash attention version does not support sliding window attention, please update to the latest version."
            assert self.use_flash_attn, "Sliding window attention is only supported with flash attention for now."
        projection_size = args.kv_channels * args.num_attention_heads
        # Fused QKV width: every query head plus one key and one value head
        # per key/value group (grouped-query attention).
        qkv_projection_size = args.kv_channels * args.num_attention_heads + 2 * args.kv_channels * args.num_attention_heads_kv
        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            args.num_attention_heads, world_size)
        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = megatron.core.tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                qkv_projection_size,
                bias=args.use_bias,
                gather_output=False,
                init_method=init_method,
                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
                **_args_to_kwargs(args),
                world_size=world_size)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = megatron.core.tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                bias=args.use_bias,
                gather_output=False,
                init_method=init_method,
                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
                **_args_to_kwargs(args),
                world_size=world_size)
            self.key_value = megatron.core.tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                bias=args.use_bias,
                gather_output=False,
                init_method=init_method,
                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
                **_args_to_kwargs(args),
                world_size=world_size,)
        self.core_attention = CoreAttention(self.layer_number, self.attn_mask_type, args, world_size)
        self.checkpoint_core_attention = args.recompute_granularity == 'selective'
        if self.use_flash_attn:
            self.core_attention_flash = flash_attn.flash_attn_func
        # Output.
        self.dense = megatron.core.tensor_parallel.RowParallelLinear(
            projection_size,
            args.hidden_size,
            bias=args.use_bias,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            **_args_to_kwargs(args),
            world_size=world_size)
        self.position_embedding_type = args.position_embedding_type
        if self.position_embedding_type == PositionEmbeddingType.rotary:
            self.freqs_cis = precompute_freqs_cis(
                dim=args.hidden_size // args.num_attention_heads,
                end=self.seq_length,
                theta=args.rope_theta,
                scaling_factor=args.rope_scaling_factor,
            )

    def _dropout_rng_context(self):
        """Return the context under which flash-attention kernels are run.

        Without sequence parallelism the kernels run inside the CUDA RNG
        tracker's forked state; with it, no special RNG context is used.
        """
        if self.sequence_parallel:
            return nullcontext()
        return megatron.core.tensor_parallel.get_cuda_rng_tracker().fork()

    def _checkpointed_attention_forward(self,
                                        query_layer,
                                        key_layer,
                                        value_layer,
                                        attention_mask):
        """Run ``core_attention`` under activation checkpointing (recomputed in backward)."""
        def custom_forward(*inputs):
            query_layer = inputs[0]
            key_layer = inputs[1]
            value_layer = inputs[2]
            attention_mask = inputs[3]
            output_ = self.core_attention(query_layer, key_layer,
                                          value_layer, attention_mask)
            return output_

        hidden_states = megatron.core.tensor_parallel.checkpoint(
            custom_forward,
            False, query_layer, key_layer, value_layer, attention_mask)
        return hidden_states

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Allocate an uninitialized [max_seq, batch, np, hn] cache tensor on the current CUDA device."""
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self,
                hidden_states,
                attention_mask,
                encoder_output=None,
                inference_params=None,
                position_ids=None):
        """Compute attention over ``hidden_states`` ([sq, b, h]).

        Args:
            hidden_states: [sq, b, h] input.
            attention_mask: for the unfused path, passed to ``CoreAttention``.
                For FlashAttention with ``packed_input``, it must be the
                [b, s] ``attention_mask_in_length`` tensor. For
                FlashAttention without packing, it is unused (causal masking
                is done by the kernel).
            encoder_output: [sk, b, h] keys/values source for cross-attention.
            inference_params: if given, this layer's key/value cache is
                allocated on first use and updated in place.
            position_ids: forwarded to ``apply_rotary_emb``.

        Returns:
            ``(output, bias)`` from the output projection.
        """
        # hidden_states: [sq, b, h]
        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================
        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, ng * (n / n_kv + 2) * hn]
            # where ng is the number of key/value groups on this partition.
            mixed_x_layer, _ = self.query_key_value(hidden_states)
            sq, b = mixed_x_layer.shape[:2]
            # Grouped-query attention: each group holds n / n_kv query heads
            # followed by one key head and one value head. We simply expand
            # the smaller key and value tensors to the query shape and then
            # feed those tensors to the standard attention / flash attention.
            qkv = mixed_x_layer.view(sq, b, -1, self.num_attention_heads // self.num_attention_heads_kv + 2, self.hidden_size_per_attention_head)
            query_layer = qkv[:, :, :, :-2]
            key_layer = qkv[:, :, :, [-2]]
            value_layer = qkv[:, :, :, [-1]]
            key_layer = torch.broadcast_to(key_layer, query_layer.shape)
            value_layer = torch.broadcast_to(value_layer, query_layer.shape)
            # [sq, b, ng, heads_per_group, hn] --> [sq, b, np, hn]
            query_layer, key_layer, value_layer = [
                rearrange(x, "seq_len batch group num_heads head_dim -> seq_len batch (group num_heads) head_dim",
                          head_dim=self.hidden_size_per_attention_head,)
                for x in [query_layer, key_layer, value_layer]]
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)
            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = megatron.core.tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================
        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]

        # ==================================
        # Rotary embeddings
        # ==================================
        # NOTE(review): rotary is applied after the cache update, i.e. the
        # cache stores un-rotated keys and the whole cached prefix is rotated
        # here with `position_ids`; confirm apply_rotary_emb handles a key
        # length longer than the query length during incremental decoding.
        if self.position_embedding_type == PositionEmbeddingType.rotary:
            query_layer, key_layer = apply_rotary_emb(query_layer, key_layer, self.freqs_cis, position_ids=position_ids)

        # ==================================
        # core attention computation
        # ==================================
        if not self.use_flash_attn:
            if self.checkpoint_core_attention:
                context_layer = self._checkpointed_attention_forward(
                    query_layer, key_layer, value_layer, attention_mask)
            else:
                context_layer = self.core_attention(
                    query_layer, key_layer, value_layer, attention_mask)
        else:
            flash_attn_extra_kwargs = {}
            # check if we need to use sliding window attention
            # https://github.com/huggingface/transformers/blob/7ee995fd9c692761c4601ddbffa2ac2ec9f27b0b/src/transformers/models/mistral/modeling_mistral.py#L353
            if self.sliding_window_size is not None:
                kv_seq_len = key_layer.shape[0]
                if kv_seq_len > self.sliding_window_size:
                    # https://github.com/huggingface/transformers/blob/7ee995fd9c692761c4601ddbffa2ac2ec9f27b0b/src/transformers/models/mistral/modeling_mistral.py#L510C21-L510C89
                    flash_attn_extra_kwargs["window_size"] = (
                        self.sliding_window_size, self.sliding_window_size
                    )
                    # It will be truncated to the actual sequence length inside flash attention
                    # https://github.com/Dao-AILab/flash-attention/blob/83aef842beec1037eb8c1d9c3ef3ed8aae80b091/csrc/flash_attn/src/softmax.h#L159-L161
            if self.packed_input:
                # assume attention_mask is `attention_mask_in_length` (which is 2D with shape batch x seqlen)
                seqlen, bsz = key_layer.shape[:2]
                assert attention_mask.shape == (bsz, seqlen), f"attention_mask shape {attention_mask.shape} does not match expected attention_mask_in_length shape {(bsz, seqlen)}"
                attention_mask_in_length = attention_mask
                # following https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1698610752
                # to handle packed input for flash attention
                qkv = torch.stack(
                    # each is [seqlen, bsz, nh, hd]
                    [query_layer, key_layer, value_layer], dim=2
                )  # [seqlen, bsz, 3, num_heads, head_dim]
                qkv = qkv.transpose(0, 1)  # [bsz, seqlen, 3, num_heads, head_dim]
                nheads = qkv.shape[-2]
                x = rearrange(qkv, "b s three n h -> b s (three n h)")
                x_unpad, indices, cu_q_lens, max_s = unpad_input_for_concatenated_sequences(
                    x, attention_mask_in_length
                )
                x_unpad = rearrange(
                    x_unpad, "nnz (three n h) -> nnz three n h", three=3, n=nheads
                )
                with self._dropout_rng_context():
                    output_unpad = flash_attn.flash_attn_varlen_qkvpacked_func(
                        x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True,
                        **flash_attn_extra_kwargs
                    )
                # Re-pad to [b, s, n * h] and split the head axis: `n` is the
                # number of heads (the original code mislabeled it as `h`).
                output = rearrange(
                    pad_input(
                        rearrange(output_unpad, "nnz n h -> nnz (n h)"),
                        indices, bsz, seqlen
                    ),
                    "b s (n h) -> b s n h",
                    n=nheads,
                )
                context_layer = rearrange(output, "b s n h -> s b (n h)").contiguous()
            else:
                q, k, v = [rearrange(t, "s b n h -> b s n h").contiguous()
                           for t in (query_layer, key_layer, value_layer)]
                with self._dropout_rng_context():
                    context_layer = self.core_attention_flash(
                        q, k, v,
                        causal=True,
                        **flash_attn_extra_kwargs
                    )
                context_layer = rearrange(context_layer, 'b s n h -> s b (n h)').contiguous()

        # =================
        # Output. [sq, b, h]
        # =================
        output, bias = self.dense(context_layer)
        return output, bias
def dropout_add(x, residual, prob, training):
    """Return ``residual + dropout(x)``.

    Args:
        x: tensor to apply dropout to.
        residual: tensor added after dropout.
        prob: dropout probability.
        training: whether dropout is active (identity when False).
    """
    out = torch.nn.functional.dropout(x, p=prob, training=training)
    out = residual + out
    return out


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    """Return ``residual + dropout(x + bias)``.

    Also called from the TorchScript-compiled fused variants, so the type
    comment above must stay in place.
    """
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    """Return ``f(x, bias, residual, prob)`` calling ``bias_dropout_add`` with ``training`` bound."""
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add


# NOTE: `dropout_add` used to be defined a second time here with an identical
# body; the redundant copy was removed.


def get_dropout_add(training):
    """Return ``f(x, residual, prob)`` calling ``dropout_add`` with ``training`` bound."""
    def _dropout_add(x, residual, prob):
        return dropout_add(x, residual, prob, training)
    return _dropout_add
@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """TorchScript-compiled ``residual + dropout(x + bias)`` with dropout enabled (training mode)."""
    return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """TorchScript-compiled ``residual + dropout(x + bias)`` with dropout disabled (inference mode)."""
    return bias_dropout_add(x, bias, residual, prob, False)
class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.

    Supports pre- or post-layernorm (``use_post_ln``), LayerNorm or RMSNorm,
    parallel attention + MLP (``parallel_attn``) with an optional dedicated
    MLP layernorm (``parallel_layernorm``), decoder cross-attention, and
    stochastic depth (``drop_path_rate``).
    """

    def __init__(self,
                 init_method: Callable,
                 output_layer_init_method: Callable,
                 layer_number: int,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate: float = 0.0,
                 world_size: int = None,
                 hidden_dropout: float = 0.0,
                 args=None):
        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type
        self.apply_residual_connection_post_layernorm = args.apply_residual_connection_post_layernorm
        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.parallel_layernorm = args.parallel_layernorm
        # Layernorm on the input data.
        if args.use_rms_norm:
            self.input_layernorm = RMSNorm(args.hidden_size, eps=args.layernorm_epsilon,
                                           sequence_parallel=args.sequence_parallel)
            self.output_layernorm = RMSNorm(args.hidden_size, eps=args.layernorm_epsilon,
                                            sequence_parallel=args.sequence_parallel)
            if self.parallel_layernorm:
                self.mlp_layernorm = RMSNorm(args.hidden_size, eps=args.layernorm_epsilon,
                                             sequence_parallel=args.sequence_parallel)
        else:
            self.input_layernorm = LayerNorm(args.hidden_size,
                                             eps=args.layernorm_epsilon,
                                             no_persist_layer_norm=args.no_persist_layer_norm,
                                             sequence_parallel=args.sequence_parallel)
            self.output_layernorm = LayerNorm(args.hidden_size,
                                              eps=args.layernorm_epsilon,
                                              no_persist_layer_norm=args.no_persist_layer_norm,
                                              sequence_parallel=args.sequence_parallel)
            if self.parallel_layernorm:
                self.mlp_layernorm = LayerNorm(args.hidden_size,
                                               eps=args.layernorm_epsilon,
                                               no_persist_layer_norm=args.no_persist_layer_norm,
                                               sequence_parallel=args.sequence_parallel)
        # Pre-LN keeps input_layernorm; post-LN keeps output_layernorm. The
        # unused one becomes an identity.
        self.use_post_ln = args.use_post_ln
        if args.use_post_ln:
            self.input_layernorm = torch.nn.Identity()
        else:
            self.output_layernorm = torch.nn.Identity()
        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type,
            world_size=world_size,
            args=args)
        self.hidden_dropout = hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
        self.parallel_attn = args.parallel_attn
        self.use_bias = args.use_bias
        # Layernorm on the attention output
        if not args.parallel_attn:
            if not args.use_rms_norm:
                self.post_attention_layernorm = LayerNorm(
                    args.hidden_size,
                    eps=args.layernorm_epsilon,
                    no_persist_layer_norm=args.no_persist_layer_norm,
                    sequence_parallel=args.sequence_parallel)
            else:
                self.post_attention_layernorm = RMSNorm(
                    args.hidden_size,
                    eps=args.layernorm_epsilon,
                    sequence_parallel=args.sequence_parallel)
        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn,
                world_size=world_size,
                args=args)
            # Layernorm on the attention output.
            if not args.use_rms_norm:
                self.post_inter_attention_layernorm = LayerNorm(
                    args.hidden_size,
                    eps=args.layernorm_epsilon,
                    no_persist_layer_norm=args.no_persist_layer_norm,
                    sequence_parallel=args.sequence_parallel)
            else:
                self.post_inter_attention_layernorm = RMSNorm(
                    args.hidden_size,
                    eps=args.layernorm_epsilon,
                    sequence_parallel=args.sequence_parallel)
        self.mlp = ParallelMLP(init_method, output_layer_init_method, args, world_size)
        # Set bias+dropout+add fusion grad_enable execution handler.
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        self.bias_dropout_add_exec_handler = \
            nullcontext if use_nvfuser else torch.enable_grad

    def forward(self,
                hidden_states: torch.Tensor,
                attention_mask: torch.Tensor,
                encoder_output=None,
                enc_dec_attn_mask=None,
                inference_params=None,
                position_ids=None):
        """Run the layer on ``hidden_states`` ([s, b, h]) and return [s, b, h].

        ``attention_mask``, ``inference_params`` and ``position_ids`` are
        forwarded to self-attention; ``encoder_output`` and
        ``enc_dec_attn_mask`` are used only by decoder layers.
        """
        ##
        # PRELIMINARIES - utilities to compute residual + dropout
        ##

        # function to compute residual + dropout(x + bias)
        def add_dropout(x, bias, residual, prob, make_viewless=False):
            # Jit compiled function creates 'view' tensor. This tensor
            # potentially gets saved in the MPU checkpoint function context,
            # which rejects view tensors. While making a viewless tensor here
            # won't result in memory savings (like the data loader, or
            # p2p_communication), it serves to document the origin of this
            # 'view' tensor.
            if self.use_bias:
                bias = bias.expand_as(residual)
                if self.drop_path is None:
                    with self.bias_dropout_add_exec_handler():
                        output = bias_dropout_add_func(x, bias, residual, prob)
                    if make_viewless:
                        return core.utils.make_viewless_tensor(inp=output,
                                                               requires_grad=output.requires_grad,
                                                               keep_graph=True)
                    return output
                out = torch.nn.functional.dropout(x + bias, p=prob, training=self.training)
                return residual + self.drop_path(out)
            elif self.drop_path is None:
                with self.bias_dropout_add_exec_handler():
                    return dropout_add_func(x, residual, prob)
            out = torch.nn.functional.dropout(x, p=prob, training=self.training)
            return residual + self.drop_path(out)

        # determine the dropout_add_func to use in the add_dropout function
        if self.drop_path is None:
            # jit scripting for a nn.module (with dropout) is not
            # triggerring the fusion kernel. For now, we use two
            # different nn.functional routines to account for varying
            # dropout semantics during training and inference phases.
            if not self.use_bias:
                dropout_add_func = get_dropout_add(self.training)
            elif self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)

        ##
        # Transformer computation begins now.
        ##
        # hidden_states: [s, b, h]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Get attention.
        attention_output, attention_bias = self.self_attention(layernorm_output,
                                                               attention_mask,
                                                               inference_params=inference_params,
                                                               position_ids=position_ids)
        # Determines the value of the next residual connection.
        # if not parallel_attn: used after the post_attention_layernorm,
        # else: used just before returning the output.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states
        # dedicated mlp layernorm module
        if self.parallel_layernorm:
            layernorm_output = self.mlp_layernorm(hidden_states)
        if self.parallel_attn:
            # Only read again below if the layer is a decoder (as the
            # cross-attention residual when not residual_post_layernorm);
            # kept for backward compatibility.
            layernorm_input = attention_output
        else:
            layernorm_input = add_dropout(attention_output, attention_bias,
                                          residual, self.hidden_dropout)
            layernorm_output = self.post_attention_layernorm(layernorm_input)
        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = self.inter_attention(layernorm_output,
                                                                    enc_dec_attn_mask,
                                                                    encoder_output=encoder_output)
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input
            layernorm_input = add_dropout(attention_output, attention_bias,
                                          residual, self.hidden_dropout)
            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # Compute MLP.
        # At this point, layernorm_output is:
        # if layer is decoder: the post_inter_attention_layernorm output,
        # elif parallel_layernorm: the mlp_layernorm output,
        # elif parallel_attention: the input_layernorm tensor.
        # else: the post_attention_layernorm output,
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.parallel_attn:
            # Attention and MLP branches are summed. Both projections return
            # their bias separately (skip_bias_add), so the attention bias
            # must be folded in here too or it would be silently dropped.
            mlp_output = mlp_output + attention_output
            if self.use_bias:
                mlp_bias = mlp_bias + attention_bias
        elif self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input
        output = add_dropout(mlp_output, mlp_bias, residual, self.hidden_dropout,
                             make_viewless=True)
        # Apply final layernorm (identity unless use_post_ln), return.
        output = self.output_layernorm(output)
        return output
class NoopTransformerLayer(MegatronModule):
    """A single 'no-op' transformer layer.

    The sole purpose of this layer is for when a standalone embedding layer
    is used (i.e., args.standalone_embedding_stage == True). In this case,
    zero transformer layers are assigned when pipeline rank == 0. Additionally,
    when virtual pipeline rank >= 1, zero total model parameters are created
    (virtual rank 0 contains the input embedding). This results in the model's
    input and output tensors being the same, which causes an error when
    performing certain memory optimizations on the output tensor (e.g.,
    deallocating it). Thus, this layer disconnects the input from the output
    via a clone. Since ranks containing a no-op layer are generally under-
    utilized (both compute and memory), there's no worry of any performance
    degradation.
    """

    def __init__(self, layer_number):
        super().__init__()
        # Stored only for parity with real layers; not used by this class.
        self.layer_number = layer_number

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None, position_ids=None):
        """Return a copy of ``hidden_states``; all other arguments are ignored.

        The signature mirrors the regular transformer layer's forward,
        including ``position_ids``, which that layer passes on to its
        self-attention. This lets the no-op layer be called with the same
        keyword arguments.

        Returns:
            ``hidden_states.clone()``: a new tensor, so that the output is
            never the same object as the input.
        """
        return hidden_states.clone()
def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
    """Compute the number of transformer layers resident on the current rank.

    Args:
        args: parsed arguments. Reads ``num_layers``, ``encoder_num_layers``,
            ``decoder_num_layers``, ``standalone_embedding_stage``,
            ``pipeline_model_parallel_split_rank`` and
            ``transformer_pipeline_model_parallel_size``.
        is_encoder_and_decoder_model: whether the pipeline is split between
            an encoder stack and a decoder stack.
        is_decoder: only consulted when pipeline parallelism is off; selects
            ``decoder_num_layers`` instead of ``encoder_num_layers``.

    Returns:
        The number of layers to build on this rank. This may be 0 on
        pipeline rank 0 when a standalone embedding stage is used.

    Raises:
        AssertionError: if the pipeline ranks given to a stack are not
            positive, or the stack's layers cannot be split evenly among them.
    """
    mpu = megatron.core.mpu

    if mpu.get_pipeline_model_parallel_world_size() <= 1:
        # No pipeline parallelism: this rank holds the entire stack.
        return args.decoder_num_layers if is_decoder else args.encoder_num_layers

    if is_encoder_and_decoder_model:
        assert args.pipeline_model_parallel_split_rank is not None
        # When a standalone embedding stage is used, a rank is taken from
        # the encoder's ranks, to be used for the encoder's embedding
        # layer. This way, the rank referenced by the 'split rank' remains
        # the same whether or not a standalone embedding stage is used.
        num_ranks_in_encoder = (
            args.pipeline_model_parallel_split_rank - 1
            if args.standalone_embedding_stage else
            args.pipeline_model_parallel_split_rank
        )
        num_ranks_in_decoder = (
            args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
        )
        # Guard the divisions below: a zero rank count would otherwise fail
        # with an opaque ZeroDivisionError, and a negative one would silently
        # produce a negative layer count.
        assert num_ranks_in_encoder > 0, \
            'number of ranks given to encoder (%d) must be positive; check ' \
            'pipeline_model_parallel_split_rank (%d)' % (
                num_ranks_in_encoder, args.pipeline_model_parallel_split_rank)
        assert num_ranks_in_decoder > 0, \
            'number of ranks given to decoder (%d) must be positive; check ' \
            'pipeline_model_parallel_split_rank (%d)' % (
                num_ranks_in_decoder, args.pipeline_model_parallel_split_rank)
        assert args.encoder_num_layers % num_ranks_in_encoder == 0, \
            'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (
                args.encoder_num_layers, num_ranks_in_encoder)
        assert args.decoder_num_layers % num_ranks_in_decoder == 0, \
            'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (
                args.decoder_num_layers, num_ranks_in_decoder)

        if not mpu.is_pipeline_stage_before_split():
            return args.decoder_num_layers // num_ranks_in_decoder
        if (args.standalone_embedding_stage
                and mpu.get_pipeline_model_parallel_rank() == 0):
            # This rank only hosts the encoder's embedding layer.
            return 0
        return args.encoder_num_layers // num_ranks_in_encoder

    assert args.num_layers == args.encoder_num_layers
    assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
        'num_layers (%d) must be divisible by transformer_pipeline_model_parallel_size (%d)' % (
            args.num_layers, args.transformer_pipeline_model_parallel_size)
    # When a standalone embedding stage is used, all transformer layers
    # are divided among pipeline rank >= 1, while on pipeline rank 0,
    # ranks either contain the input embedding layer (virtual pp rank 0),
    # or no layers at all (virtual pp rank >= 1).
    if (args.standalone_embedding_stage
            and mpu.get_pipeline_model_parallel_rank() == 0):
        return 0
    return args.num_layers // args.transformer_pipeline_model_parallel_size
class ParallelTransformer(MegatronModule):
def __init__(self,
init_method: Callable,
output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
pre_process=True,
post_process=True,
drop_path_rate=0.0,
args=None,
model_type=None):
super(ParallelTransformer, self).__init__()
world_size = megatron.core.mpu.get_tensor_model_parallel_world_size()
assert args is not None
assert model_type is not None
self.layer_type = layer_type
self.model_type = model_type
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
self.drop_path_rate = drop_path_rate
self.transformer_impl = args.transformer_impl
# Store activation checkpointing flag.