-
Notifications
You must be signed in to change notification settings - Fork 4
/
jax_flash_attn_tpu.py
1695 lines (1495 loc) · 53.8 KB
/
jax_flash_attn_tpu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flash Attention TPU kernel."""
from __future__ import annotations
import dataclasses
import functools
from typing import Any, NamedTuple
import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
# Additive penalty applied to masked-out attention logits: a very large
# negative float32 that stays finite (0.7 of the float32 maximum, negated).
DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
# Size of the trailing axis that Q segment ids are broadcast to before being
# handed to the kernels (presumably the TPU vector lane count -- TODO confirm).
NUM_LANES = 128
# Size of the middle axis that KV segment ids are broadcast to before being
# handed to the kernels (presumably the TPU sublane count -- TODO confirm).
NUM_SUBLANES = 8
class SegmentIds(NamedTuple):
  """Per-token segment ids for the Q and KV sequences.

  The kernels turn these ids into a segment mask: a query position may only
  attend to key/value positions carrying the same integer id, so tokens that
  belong to different segments of the input never attend to each other.

  Attributes:
    q: segment ids along the Q sequence, shaped [batch_size, q_seq_len].
    kv: segment ids along the KV sequence, shaped [batch_size, kv_seq_len].
  """

  q: jax.Array
  kv: jax.Array
@dataclasses.dataclass(frozen=True)
class BlockSizes:
  """Tile sizes parameterizing FlashAttention kernels.

  Those parameters have negligible effect on numerics, but affect performance
  greatly.

  `block_q`, `block_k_major`, `block_k` and `block_b` configure the forward
  kernel. The `*_dkv` and `*_dq` fields configure the backward kernels; they
  default to None and must all be set before the attention is differentiated
  (see `has_backward_blocks`).

  Raises:
    ValueError: on construction, if a minor block size is larger than, or does
      not evenly divide, its matching major block size.
  """

  # Forward-pass tiling.
  block_q: int
  block_k_major: int
  block_k: int
  block_b: int

  # Backward-pass tiling for the dK/dV kernel.
  block_q_major_dkv: int | None = None
  block_k_major_dkv: int | None = None
  block_k_dkv: int | None = None
  block_q_dkv: int | None = None

  # Backward-pass tiling for the dQ kernel.
  block_k_major_dq: int | None = None
  block_k_dq: int | None = None
  block_q_dq: int | None = None

  def __post_init__(self):
    """Checks that every (major, minor) block pair is consistent."""

    def verify_major_minor(prefix, suffix, major, minor):
      # The minor block tiles the major block, so it can be at most as large
      # and must divide it exactly.
      if minor > major:
        raise ValueError(
            f"{prefix}{suffix}={minor} should be smaller than"
            f" {prefix}_major{suffix}={major}"
        )
      if major % minor != 0:
        raise ValueError(
            f"{prefix}{suffix}={minor} should divide"
            f" {prefix}_major{suffix}={major}"
        )

    verify_major_minor("block_k", "", self.block_k_major, self.block_k)
    # Backward pairs are optional; only validate the ones fully specified.
    if self.block_q_major_dkv is not None and self.block_q_dkv is not None:
      verify_major_minor(
          "block_q", "_dkv", self.block_q_major_dkv, self.block_q_dkv
      )
    if self.block_k_major_dkv is not None and self.block_k_dkv is not None:
      verify_major_minor(
          "block_k", "_dkv", self.block_k_major_dkv, self.block_k_dkv
      )
    if self.block_k_major_dq is not None and self.block_k_dq is not None:
      verify_major_minor(
          "block_k", "_dq", self.block_k_major_dq, self.block_k_dq
      )

  @property
  def has_backward_blocks(self) -> bool:
    """True iff every backward-pass block size has been specified."""
    backward_blocks = (
        self.block_q_major_dkv,
        self.block_k_major_dkv,
        self.block_q_dkv,
        self.block_k_dkv,
        self.block_k_major_dq,
        self.block_k_dq,
        self.block_q_dq,
    )
    return all(b is not None for b in backward_blocks)

  @classmethod
  def get_default(
      cls, batch_size, num_heads, q_seq_len, kv_len, d_model
  ) -> BlockSizes:
    """Returns default block sizes (forward and backward all set).

    The problem dimensions are accepted for API stability but are currently
    ignored: every block is 128 and `block_b` is 1.
    """
    # TODO(apaszke,sharadmv): Select better parameters based on a heuristic.
    del batch_size, num_heads, q_seq_len, kv_len, d_model  # Unused.
    # Build through `cls` so that subclasses get an instance of their own type.
    return cls(
        block_q=128,
        block_k_major=128,
        block_k=128,
        block_b=1,
        block_q_major_dkv=128,
        block_k_major_dkv=128,
        block_k_dkv=128,
        block_q_dkv=128,
        block_k_major_dq=128,
        block_k_dq=128,
        block_q_dq=128,
    )
@functools.partial(
    jax.jit,
    static_argnames=("causal", "sm_scale", "block_sizes", "debug"),
)
def flash_attention(
    q,  # [batch_size, num_heads, q_seq_len, d_model]
    k,  # [batch_size, num_heads, kv_seq_len, d_model]
    v,  # [batch_size, num_heads, kv_seq_len, d_model]
    ab=None,  # [batch_size, num_heads, q_seq_len, kv_seq_len]
    segment_ids=None,  # q of [batch_size, q_seq_len] and kv of [batch_size, kv_seq_len]
    *,
    causal: bool = False,
    sm_scale: float = 1.0,
    block_sizes: BlockSizes | None = None,
    debug: bool = False,
):
  """Validates the inputs and runs the FlashAttention TPU kernel.

  Args:
    q: queries, [batch_size, num_heads, q_seq_len, d_model].
    k: keys, [batch_size, num_heads, kv_seq_len, d_model].
    v: values, [batch_size, num_heads, kv_seq_len, d_model].
    ab: optional attention bias,
      [batch_size, num_heads, q_seq_len, kv_seq_len].
    segment_ids: optional SegmentIds with `q` of [batch_size, q_seq_len] and
      `kv` of [batch_size, kv_seq_len].
    causal: whether to apply a causal mask.
    sm_scale: scale applied to the attention logits.
    block_sizes: kernel tile sizes; `BlockSizes.get_default` is used if None.
    debug: forwarded to the underlying Pallas call.

  Returns:
    The attention output (same shape and dtype as `q`).

  Raises:
    ValueError: if the shapes of the inputs are inconsistent.
    NotImplementedError: if the last dimension of `v` differs from that of `q`.
  """
  batch_size, num_heads, q_seq_len, d_model = q.shape
  batch_size_k, num_heads_k, kv_seq_len, d_model_k = k.shape
  batch_size_v, num_heads_v, kv_seq_len_v, d_model_v = v.shape

  # q / k / v must agree on batch, heads and feature size; k / v on length.
  if not (batch_size == batch_size_k and batch_size == batch_size_v):
    raise ValueError(
        f"Batch size mismatch: got {batch_size}, {batch_size_k} and"
        f" {batch_size_v} (for q, k, v respectively)"
    )
  if not (num_heads == num_heads_k and num_heads == num_heads_v):
    raise ValueError(
        f"Head count mismatch: got {num_heads}, {num_heads_k},"
        f" {num_heads_v} (for q, k, v respectively)"
    )
  if d_model_k != d_model:
    raise ValueError(
        f"Model dimension mismatch: got {d_model} and {d_model_k} (for q and k"
        " respectively)"
    )
  if d_model_v != d_model:
    raise NotImplementedError(
        "V model dimension unequal to KV model dimension unsupported"
    )
  if kv_seq_len_v != kv_seq_len:
    raise ValueError(
        f"KV sequence length mismatch: got {kv_seq_len} and {kv_seq_len_v}"
    )

  # Optional operands must match the shapes implied by q / k.
  if ab is not None and ab.shape != (
      batch_size,
      num_heads,
      q_seq_len,
      kv_seq_len,
  ):
    raise ValueError(
        f"Attention bias shape mismatch: expected ({batch_size=},"
        f" {num_heads=}, {q_seq_len=}, {kv_seq_len=}), got {ab.shape}"
    )
  if segment_ids is not None:
    if segment_ids.q.shape != (batch_size, q_seq_len):
      raise ValueError(
          f"Q segment ids shape mismatch: expected ({batch_size=},"
          f" {q_seq_len=},), got {segment_ids.q.shape}"
      )
    if segment_ids.kv.shape != (batch_size, kv_seq_len):
      raise ValueError(
          f"KV segment ids shape mismatch: expected ({batch_size=},"
          f" {kv_seq_len=},), got {segment_ids.kv.shape}"
      )

  if block_sizes is None:
    block_sizes = BlockSizes.get_default(
        batch_size, num_heads, q_seq_len, kv_seq_len, d_model
    )

  # The public entry point never asks for the softmax residuals itself; the
  # custom VJP forward rule requests them when differentiating.
  save_residuals = False
  return _flash_attention(
      q,
      k,
      v,
      ab,
      segment_ids,
      save_residuals,
      causal,
      sm_scale,
      block_sizes,
      debug,
  )
@functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 10))
def _flash_attention(
    q,
    k,
    v,
    ab,
    segment_ids,
    save_residuals,
    causal,
    sm_scale,
    block_sizes,
    debug,
):
  """Differentiable FlashAttention primitive (custom VJP entry point).

  Arguments 5..9 (`save_residuals` through `debug`) are non-differentiable.
  The forward-pass tile sizes are unpacked from `block_sizes` and everything
  is handed to `_flash_attention_impl`.
  """
  forward_blocks = (
      block_sizes.block_b,
      block_sizes.block_q,
      block_sizes.block_k_major,
      block_sizes.block_k,
  )
  return _flash_attention_impl(
      q,
      k,
      v,
      ab,
      segment_ids,
      save_residuals,
      causal,
      sm_scale,
      *forward_blocks,
      debug,
  )
def _flash_attention_fwd(
    q,
    k,
    v,
    ab,
    segment_ids,
    save_residuals,
    causal,
    sm_scale,
    block_sizes,
    debug,
):
  """Forward rule of the custom VJP.

  Re-runs the attention with residual saving switched on and returns the
  output together with everything the backward rule needs.

  Raises:
    NotImplementedError: if called with `save_residuals` already set.
  """
  if save_residuals:
    raise NotImplementedError("Higher-order AD not supported")
  # With save_residuals=True the primitive returns (o, l, m).
  outputs = _flash_attention(
      q, k, v, ab, segment_ids, True, causal, sm_scale, block_sizes, debug
  )
  o, l, m = outputs
  residuals = (q, k, v, ab, segment_ids, o, l, m)
  return o, residuals
def _flash_attention_bwd(
    save_residuals: bool,
    causal: bool,
    sm_scale: float,
    block_sizes: BlockSizes,
    debug: bool,
    residuals,
    do,
):
  """VJP rule for FlashAttention.

  Args:
    save_residuals: must be False; see Raises.
    causal: whether the forward pass used a causal mask.
    sm_scale: logit scale used in the forward pass.
    block_sizes: tile sizes; all backward blocks must be specified.
    debug: forwarded to the backward kernels.
    residuals: the tuple saved by the forward rule,
      (q, k, v, ab, segment_ids, o, l, m).
    do: cotangent of the attention output.

  Returns:
    Cotangents (dq, dk, dv, ds, None) for (q, k, v, ab, segment_ids).

  Raises:
    NotImplementedError: if `save_residuals` is set.
    ValueError: if `block_sizes` lacks any backward block size.
  """
  if save_residuals:
    raise NotImplementedError("Higher-order AD not supported")
  q, k, v, ab, segment_ids, o, l, m = residuals
  if not block_sizes.has_backward_blocks:
    raise ValueError(
        "Program is being differentiated, but not all backward blocks are"
        " specified"
    )

  # Row-wise dot product of the output with its cotangent, in float32.
  # Shape: [batch_size, num_heads, q_seq_len].
  o_f32 = o.astype(jnp.float32)
  do_f32 = do.astype(jnp.float32)
  di = jnp.sum(o_f32 * do_f32, axis=-1)

  # Operands and options shared by both backward kernels.
  operands = (q, k, v, ab, segment_ids, l, m, do, di)
  shared_options = dict(
      sm_scale=sm_scale,
      causal=causal,
      mask_value=DEFAULT_MASK_VALUE,
      debug=debug,
  )

  dk, dv = _flash_attention_bwd_dkv(
      *operands,
      block_q_major=block_sizes.block_q_major_dkv,
      block_k_major=block_sizes.block_k_major_dkv,
      block_k=block_sizes.block_k_dkv,
      block_q=block_sizes.block_q_dkv,
      **shared_options,
  )
  dq, ds = _flash_attention_bwd_dq(
      *operands,
      block_q_major=block_sizes.block_q_dq,
      block_k_major=block_sizes.block_k_major_dq,
      block_k=block_sizes.block_k_dq,
      **shared_options,
  )
  # segment_ids get no gradient.
  return dq, dk, dv, ds, None
# Register the forward / backward rules used when _flash_attention is
# differentiated.
_flash_attention.defvjp(fwd=_flash_attention_fwd, bwd=_flash_attention_bwd)
# Width of the trailing axis of the running-max / running-sum (m, l) buffers.
# The multi-step forward kernel also requires block_k to be a multiple of it.
MIN_BLOCK_SIZE = 128
# dot_general dimension numbers that contract axis 1 of both operands with no
# batch dimensions, i.e. compute lhs @ rhs.T.
TRANS_B_DIM_NUMBERS = (((1,), (1,)), ((), ()))
def below_or_on_diag(r, r_blk_size, c, c_blk_size):
  """Returns whether block (r, c) has any element on or below the diagonal.

  Args:
    r: block index along the row (query) axis.
    r_blk_size: number of rows per block.
    c: block index along the column (key/value) axis.
    c_blk_size: number of columns per block.

  Returns:
    True (or a traced boolean, when the indices are traced) iff the block must
    be processed under a causal mask.
  """
  # A block is considered below or on diagonal as long as the bottom left
  # corner of the block is below or on diagonal. The comparison must be
  # inclusive: when the corner sits exactly on the diagonal that single
  # element is still unmasked, so the block cannot be skipped.
  bottom_row = (r + 1) * r_blk_size - 1
  left_col = c * c_blk_size
  return bottom_row >= left_col
def _flash_attention_kernel(q_tile_ref, *args, **kwargs):
  """Forward kernel entry point: runs a per-batch kernel over the batch tile.

  Picks the single-step kernel when one KV block spans the whole KV sequence
  and the tiled kernel otherwise, then invokes it once per batch element held
  in `q_tile_ref` (its leading axis).
  """
  num_batches = q_tile_ref.shape[0]
  # If we're not going to tile the softmax, then we can avoid a bunch of VPU
  # ops.
  softmax_is_untiled = kwargs["block_k"] == kwargs["kv_seq_len"]
  per_batch_kernel = (
      _flash_attention_kernel_single_batch_single_step
      if softmax_is_untiled
      else _flash_attention_kernel_single_batch
  )
  for b in range(num_batches):
    # The second index selects the single head held in the tile.
    per_batch_kernel((b, 0), q_tile_ref, *args, **kwargs)
def _flash_attention_kernel_single_batch(
    batch_idx: tuple[int, ...],
    q_tile_ref,
    k_tile_ref,
    v_tile_ref,
    ab_tile_ref,
    q_segment_ids_tile_ref,
    kv_segment_ids_tile_ref,  # Input arrays
    o_tile_ref,  # Output arrays
    m_scratch_ref,
    l_scratch_ref,
    acc_scratch_ref,
    l_ref: Any | None = None,
    m_ref: Any | None = None,
    *,
    causal,
    sm_scale,
    block_k,
    kv_seq_len,
    mask_value,
):
  """Tiled forward kernel for one batch element (softmax split over KV).

  Grid axis 3 walks the major KV blocks and grid axis 2 the Q blocks. The
  running row max (`m`), running row sum (`l`) and the output accumulator are
  carried across KV steps in the scratch refs; they are reset on the first KV
  step and copied to the outputs on the last one.

  Args:
    batch_idx: leading indices (batch element, head) selecting the slice of
      each 4D tile ref to work on.
    q_tile_ref, k_tile_ref, v_tile_ref: query / key / value tiles.
    ab_tile_ref: attention-bias tile, or None.
    q_segment_ids_tile_ref, kv_segment_ids_tile_ref: segment-id tiles, or None.
    o_tile_ref: output tile.
    m_scratch_ref, l_scratch_ref, acc_scratch_ref: float32 carry buffers.
    l_ref, m_ref: optional outputs receiving the final `l` and `m`.
    causal: whether to apply a causal mask (and skip fully masked blocks).
    sm_scale: multiplier applied to the logits.
    block_k: minor KV block size iterated inside each major KV block.
    kv_seq_len: total KV sequence length.
    mask_value: value added to masked-out logits.

  Raises:
    NotImplementedError: if `block_k` is not a multiple of the lane / minimum
      block size, or `head_dim` exceeds MIN_BLOCK_SIZE without being a
      multiple of it.
  """
  block_k_major = k_tile_ref.shape[2]
  block_q = q_tile_ref.shape[2]
  head_dim = q_tile_ref.shape[-1]
  kv_seq_idx = pl.program_id(3)
  # First KV step for this Q block: reset the running statistics.
  @pl.when(kv_seq_idx == 0)
  def start_new_sequence():
    m_scratch_ref[batch_idx] = jnp.full(
        m_scratch_ref.shape[2:], -jnp.inf, jnp.float32
    )
    l_scratch_ref[batch_idx] = jnp.zeros(l_scratch_ref.shape[2:], jnp.float32)
    acc_scratch_ref[batch_idx] = jnp.zeros(
        acc_scratch_ref.shape[2:], jnp.float32
    )
  q_seq_idx = pl.program_id(2)
  # Under a causal mask, blocks entirely above the diagonal contribute nothing.
  if causal:
    should_run = below_or_on_diag(q_seq_idx, block_q, kv_seq_idx, block_k_major)
  else:
    should_run = True
  @pl.when(should_run)
  def run():
    # Walk the minor KV blocks inside the current major block; the loop is
    # applied immediately by the decorator and carries no value.
    @functools.partial(
        lax.fori_loop, 0, block_k_major // block_k, init_val=None, unroll=True
    )
    def body(i, _):
      m_prev = m_scratch_ref[batch_idx]
      l_prev = l_scratch_ref[batch_idx]
      q = q_tile_ref[batch_idx]  # [block_q, head_dim]
      start_k = i * block_k
      k = pl.load(
          k_tile_ref, (*batch_idx, pl.dslice(start_k, block_k), slice(None))
      )  # [block_k, head_dim]
      s = jax.lax.dot_general(
          q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32
      )  # [block_q, block_k]
      # Add attention bias if needed.
      # TODO(tanburn) Should the attention bias be added before or after
      # multiplication by sm_scale?
      if ab_tile_ref is not None:
        ab = pl.load(
            ab_tile_ref,
            (*batch_idx, pl.dslice(None), pl.dslice(start_k, block_k))
        ).astype(jnp.float32)
        s += ab
      if sm_scale != 1.0:
        s *= sm_scale
      mask = None
      if q_segment_ids_tile_ref is not None:
        repeats, rem = divmod(block_k, NUM_LANES)
        if rem:
          raise NotImplementedError(
              f"kv block size must be a multiple of {NUM_LANES}"
          )
        # Tile the lane-broadcast Q ids out to the KV block width.
        q_segment_ids = pltpu.repeat(
            q_segment_ids_tile_ref[batch_idx[0]], repeats, axis=1
        )  # [block_q, block_k].
        kv_segment_ids = pl.load(
            kv_segment_ids_tile_ref,
            (batch_idx[0], pl.dslice(1), pl.dslice(start_k, block_k)),
        )  # [1, block_k].
        mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_)
      if causal:
        # Global row / column positions of every element of this tile.
        mask_shape = (block_q, block_k)
        row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
        row_ids += q_seq_idx * block_q
        col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
        col_ids += kv_seq_idx * block_k_major + start_k
        causal_mask = col_ids <= row_ids
        mask = (
            causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
        )
      # Masking is additive: masked-out logits get mask_value added.
      s = s if mask is None else s + jnp.where(mask, 0.0, mask_value)
      m_curr = jnp.max(s, axis=1)[:, None]  # Row max, shape [block_q, 1].
      m_next = jnp.maximum(m_prev, m_curr)  # Shape [block_q, 128].
      block_k_repeats, rem = divmod(block_k, MIN_BLOCK_SIZE)
      if rem:
        raise NotImplementedError(
            f"{block_k=} should be a multiple of {MIN_BLOCK_SIZE}"
        )
      p = jnp.exp(s - pltpu.repeat(m_next, block_k_repeats, 1))
      # Rescaling factor for statistics computed against the previous max.
      alpha = jnp.exp(m_prev - m_next)  # Shape [block_q, 128].
      l_corr = alpha * l_prev
      l_next = jnp.sum(p, axis=1)[:, None] + l_corr  # Shape [block_q, 128]
      # `l` is MIN_BLOCK_SIZE wide; widen (or narrow) it to head_dim columns.
      head_dim_repeats, rem = divmod(head_dim, MIN_BLOCK_SIZE)
      l_broadcast = lambda l: pltpu.repeat(l, head_dim_repeats, 1)
      if rem:
        if head_dim_repeats == 0:
          l_broadcast = lambda l: l[:, :head_dim]
        else:
          raise NotImplementedError(
              f"{head_dim=} should be a multiple of {MIN_BLOCK_SIZE} if larger"
          )
      l_scratch_ref[batch_idx] = l_next
      m_scratch_ref[batch_idx] = m_next
      # Guard the reciprocal against rows whose running sum is zero.
      l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next)
      # The accumulator is kept normalized: first rescale the old contents...
      acc_scratch_ref[batch_idx] *= l_broadcast(l_corr * l_next_inv_safe)
      v = pl.load(
          v_tile_ref, (*batch_idx, pl.dslice(start_k, block_k), slice(None))
      )
      o_curr = jax.lax.dot(
          p.astype(v.dtype), v, preferred_element_type=jnp.float32
      )
      # ...then add this block's normalized contribution.
      acc_scratch_ref[batch_idx] += o_curr * l_broadcast(l_next_inv_safe)
  # Last KV step: write the accumulated result (and residuals) out.
  @pl.when(kv_seq_idx == (kv_seq_len // block_k_major) - 1)
  def store_output():
    o_tile_ref[batch_idx] = acc_scratch_ref[batch_idx].astype(o_tile_ref.dtype)
    if l_ref is not None:
      l_ref[batch_idx] = l_scratch_ref[batch_idx].astype(l_ref.dtype)
    if m_ref is not None:
      m_ref[batch_idx] = m_scratch_ref[batch_idx].astype(m_ref.dtype)
def _flash_attention_kernel_single_batch_single_step(
    batch_idx: tuple[int, ...],
    q_tile_ref,
    k_tile_ref,
    v_tile_ref,
    ab_tile_ref,
    q_segment_ids_tile_ref,
    kv_segment_ids_tile_ref,  # Input arrays
    o_tile_ref,  # Output arrays
    m_scratch_ref,
    l_scratch_ref,
    acc_scratch_ref,
    l_ref: Any | None = None,
    m_ref: Any | None = None,
    *,
    causal,
    sm_scale,
    block_k,
    kv_seq_len,
    mask_value,
):
  """Untiled forward kernel for one batch element.

  Used when a single KV block covers the whole KV sequence, so the softmax is
  computed in one shot and no scratch carry buffers are needed (all three
  scratch refs must be None).

  Args:
    batch_idx: leading indices (batch element, head) selecting the slice of
      each 4D tile ref to work on.
    q_tile_ref, k_tile_ref, v_tile_ref: query / key / value tiles.
    ab_tile_ref: attention-bias tile, or None.
    q_segment_ids_tile_ref, kv_segment_ids_tile_ref: segment-id tiles, or None.
    o_tile_ref: output tile.
    m_scratch_ref, l_scratch_ref, acc_scratch_ref: must all be None.
    l_ref, m_ref: optional outputs receiving the row sum and row max.
    causal: whether to apply a causal mask.
    sm_scale: multiplier applied to the logits.
    block_k: KV block size; must equal `kv_seq_len`.
    kv_seq_len: total KV sequence length.
    mask_value: value added to masked-out logits.

  Raises:
    NotImplementedError: if segment ids are given and `block_k` is not a
      multiple of NUM_LANES.
  """
  block_k_major = k_tile_ref.shape[2]
  block_q = q_tile_ref.shape[2]

  assert (
      m_scratch_ref is None
      and l_scratch_ref is None
      and acc_scratch_ref is None
  )
  assert kv_seq_len == block_k_major and block_k_major == block_k

  queries = q_tile_ref[batch_idx]  # [block_q, head_dim]
  keys = k_tile_ref[batch_idx]  # [block_k, head_dim]
  logits = jax.lax.dot_general(
      queries, keys, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32
  )  # [block_q, block_k]

  if ab_tile_ref is not None:
    logits = logits + ab_tile_ref[batch_idx].astype(jnp.float32)
  if sm_scale != 1.0:
    logits = logits * sm_scale

  keep = None
  if q_segment_ids_tile_ref is not None:
    lane_repeats, leftover = divmod(block_k, NUM_LANES)
    if leftover:
      raise NotImplementedError(
          f"kv block size must be a multiple of {NUM_LANES}"
      )
    q_ids = pl.load(
        q_segment_ids_tile_ref, (batch_idx[0],)
    )  # [block_q, NUM_LANES].
    q_ids = pltpu.repeat(q_ids, lane_repeats, axis=1)  # [block_q, block_k].
    kv_ids = pl.load(
        kv_segment_ids_tile_ref, (batch_idx[0], pl.dslice(1))
    )  # [1, block_k].
    keep = jnp.equal(q_ids, kv_ids).astype(jnp.bool_)

  if causal:
    q_block_index = pl.program_id(2)
    tile_shape = (block_q, block_k)
    rows = jax.lax.broadcasted_iota(jnp.int32, tile_shape, 0)
    rows = rows + q_block_index * block_q
    # There is a single KV block, so column ids need no offset.
    cols = jax.lax.broadcasted_iota(jnp.int32, tile_shape, 1)
    lower_triangle = cols <= rows
    if keep is None:
      keep = lower_triangle
    else:
      keep = jnp.logical_and(keep, lower_triangle)

  # Masking is additive: masked-out logits get mask_value added.
  if keep is not None:
    logits = logits + jnp.where(keep, 0.0, mask_value)

  row_max = jnp.max(logits, axis=1)[:, None]
  unnormalized = jnp.exp(logits - row_max)
  row_sum = jnp.sum(unnormalized, axis=1)[:, None]
  probs = unnormalized / row_sum

  # Residuals are stored broadcast across the trailing axis of their refs.
  if m_ref is not None:
    m_ref[batch_idx] = lax.broadcast_in_dim(
        row_max, m_ref.shape[2:], range(2)
    )
  if l_ref is not None:
    l_ref[batch_idx] = lax.broadcast_in_dim(
        row_sum, l_ref.shape[2:], range(2)
    )

  values = v_tile_ref[batch_idx]
  attended = jax.lax.dot(
      probs.astype(values.dtype), values, preferred_element_type=jnp.float32
  )
  o_tile_ref[batch_idx] = attended.astype(o_tile_ref.dtype)
def _flash_attention_impl(
    q,
    k,
    v,
    ab,
    segment_ids,
    save_residuals,
    causal,
    sm_scale,
    block_b,
    block_q,
    block_k_major,
    block_k,
    debug,
):
  """Builds and launches the forward FlashAttention Pallas call.

  Args:
    q: queries, [batch_size, num_heads, q_seq_len, head_dim].
    k: keys, [batch_size, num_heads, kv_seq_len, head_dim].
    v: values, same layout as `k`.
    ab: optional attention bias, or None.
    segment_ids: optional SegmentIds, or None.
    save_residuals: if True, also return the softmax statistics `l` and `m`.
    causal: whether to apply a causal mask.
    sm_scale: multiplier applied to the logits.
    block_b: number of batch elements per grid step.
    block_q: number of query rows per grid step.
    block_k_major: number of KV rows fetched per grid step.
    block_k: number of KV rows processed per inner-loop iteration.
    debug: forwarded to `pl.pallas_call`.

  Returns:
    `o` (same shape and dtype as `q`), or `(o, l, m)` when `save_residuals`
    is set, where `l` and `m` are [batch_size, num_heads, q_seq_len].
  """
  batch_size, num_heads, q_seq_len, head_dim = q.shape
  _, _, kv_seq_len, _ = k.shape
  _verify_block("block_q", "q_seq_len", block_q, q_seq_len, should_divide=False)
  _verify_block("block_k_major", "kv_seq_len", block_k_major, kv_seq_len)
  _verify_block("block_k", "kv_seq_len", block_k, kv_seq_len)
  _verify_block("block_b", "batch", block_b, batch_size, should_divide=False)

  # block_q is not required to divide q_seq_len, so the number of q blocks is
  # a ceiling division. Every "last q block" test below must use this value.
  num_q_blocks = pl.cdiv(q_seq_len, block_q)

  # TODO(apaszke): Tile over heads as well.
  grid = (
      pl.cdiv(batch_size, block_b),
      num_heads,
      num_q_blocks,
      kv_seq_len // block_k_major,
  )

  def next_kv_block(q_seq_index, kv_seq_index):
    """KV block to fetch at this grid step."""
    if not causal:
      return kv_seq_index
    # If the kv block is skipped, prefetch the next valid kv block, i.e. the
    # 0th one to be used for the next block_q rows.
    return lax.select(
        below_or_on_diag(q_seq_index, block_q, kv_seq_index, block_k_major),
        kv_seq_index,
        0,
    )

  def q_index_map(batch_index, head_index, q_seq_index, _):
    return (batch_index, head_index, q_seq_index, 0)

  def kv_index_map(batch_index, head_index, q_seq_index, kv_seq_index):
    next_kv_index = next_kv_block(q_seq_index, kv_seq_index)
    return (batch_index, head_index, next_kv_index, 0)

  def ab_index_map(batch_index, head_index, q_seq_index, kv_seq_index):
    if causal:
      should_run = below_or_on_diag(
          q_seq_index, block_q, kv_seq_index, block_k_major
      )
      # If the ab block is skipped, prefetch the next valid ab block, i.e. the
      # 0th kv to be used for the next block_q rows. After the last q block
      # wrap around to block 0; the last block index comes from the same
      # ceiling division as the grid so a ragged q_seq_len never yields an
      # out-of-range block index.
      next_q_index = lax.select(
          should_run,
          q_seq_index,
          lax.select(q_seq_index == num_q_blocks - 1, 0, q_seq_index + 1),
      )
      next_kv_index = lax.select(should_run, kv_seq_index, 0)
    else:
      next_q_index = q_seq_index
      next_kv_index = kv_seq_index
    return (batch_index, head_index, next_q_index, next_kv_index)

  def o_index_map(batch_index, head_index, q_seq_index, _):
    return (batch_index, head_index, q_seq_index, 0)

  def lm_index_map(batch_index, head_index, q_seq_index, _):
    return (batch_index, head_index, q_seq_index, 0)

  kernel = functools.partial(
      _flash_attention_kernel,
      causal=causal,
      mask_value=DEFAULT_MASK_VALUE,
      sm_scale=sm_scale,
      block_k=block_k,
      kv_seq_len=kv_seq_len,
  )
  out_shape = [jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)]
  out_specs = [pl.BlockSpec(o_index_map, (block_b, 1, block_q, head_dim))]

  if block_k != kv_seq_len:
    # The tiled kernel carries running softmax statistics and the output
    # accumulator across KV steps; they are declared as extra outputs that
    # always map to block 0.
    scratch_shape = functools.partial(jax.ShapeDtypeStruct, dtype=jnp.float32)
    m_scratch = scratch_shape((block_b, 1, block_q, MIN_BLOCK_SIZE))
    l_scratch = scratch_shape((block_b, 1, block_q, MIN_BLOCK_SIZE))
    acc_scratch = scratch_shape((block_b, 1, block_q, head_dim))
    out_shape += [m_scratch, l_scratch, acc_scratch]
    out_specs += [
        pl.BlockSpec(lambda *_: (0, 0, 0, 0), m_scratch.shape),
        pl.BlockSpec(lambda *_: (0, 0, 0, 0), l_scratch.shape),
        pl.BlockSpec(lambda *_: (0, 0, 0, 0), acc_scratch.shape),
    ]
  else:
    # The single-step kernel needs no carry buffers; keep the positional
    # layout of the kernel arguments with placeholders.
    out_shape += [None, None, None]
    out_specs += [None, None, None]

  if save_residuals:
    out_specs = [
        *out_specs,
        pl.BlockSpec(lm_index_map, (block_b, 1, block_q, MIN_BLOCK_SIZE)),
        pl.BlockSpec(lm_index_map, (block_b, 1, block_q, MIN_BLOCK_SIZE)),
    ]
    l = jax.ShapeDtypeStruct(
        (batch_size, num_heads, q_seq_len, MIN_BLOCK_SIZE), dtype=jnp.float32
    )
    m = jax.ShapeDtypeStruct(
        (batch_size, num_heads, q_seq_len, MIN_BLOCK_SIZE), dtype=jnp.float32
    )
    out_shape = (*out_shape, l, m)

  ab_block_spec = (
      pl.BlockSpec(ab_index_map, (block_b, 1, block_q, block_k_major))
      if ab is not None
      else None
  )

  q_segment_ids_spec = kv_segment_ids_spec = None
  q_segment_ids = kv_segment_ids = None
  if segment_ids is not None:

    def q_segment_ids_index_map(batch_index, head_index, q_seq_index, _):
      del head_index
      return (batch_index, q_seq_index, 0)

    def kv_segment_ids_index_map(
        batch_index, head_index, q_seq_index, kv_seq_index
    ):
      del head_index
      next_kv_index = next_kv_block(q_seq_index, kv_seq_index)
      return (batch_index, 0, next_kv_index)

    q_segment_ids_spec = pl.BlockSpec(
        q_segment_ids_index_map, (block_b, block_q, NUM_LANES)
    )
    kv_segment_ids_spec = pl.BlockSpec(
        kv_segment_ids_index_map, (block_b, NUM_SUBLANES, block_k_major)
    )

    # Replicate the ids along a new trailing (q) / middle (kv) axis so the
    # kernel can compare them as 2D tiles.
    q_segment_ids = jax.lax.broadcast_in_dim(
        segment_ids.q,
        (batch_size, q_seq_len, NUM_LANES),
        (0, 1),
    )
    kv_segment_ids = jax.lax.broadcast_in_dim(
        segment_ids.kv,
        (batch_size, NUM_SUBLANES, kv_seq_len),
        (0, 2),
    )

  in_specs = [
      pl.BlockSpec(q_index_map, (block_b, 1, block_q, head_dim)),
      pl.BlockSpec(kv_index_map, (block_b, 1, block_k_major, head_dim)),
      pl.BlockSpec(kv_index_map, (block_b, 1, block_k_major, head_dim)),
      ab_block_spec,
      q_segment_ids_spec,
      kv_segment_ids_spec,
  ]

  o, *aux = pl.pallas_call(
      kernel,
      out_shape=out_shape,
      in_specs=in_specs,
      out_specs=out_specs,
      grid=grid,
      debug=debug,
      mosaic_params=dict(
          dimension_semantics=("parallel", "parallel", "parallel", "arbitrary")
      ),
  )(q, k, v, ab, q_segment_ids, kv_segment_ids)
  if save_residuals:
    # l and m are stored replicated along the trailing axis; keep one column.
    l, m = (residual[..., 0] for residual in aux[-2:])
    return (o, l, m)
  else:
    return o
def _flash_attention_dkv_kernel(
    q_tile_ref,
    k_tile_ref,
    v_tile_ref,
    ab_tile_ref,
    q_segment_ids_tile_ref,
    kv_segment_ids_tile_ref,
    l_tile_ref,
    m_tile_ref,
    do_tile_ref,
    di_tile_ref,
    dk_tile_ref,
    dv_tile_ref,
    dk_scratch_ref,
    dv_scratch_ref,
    *,
    sm_scale: float,
    causal: bool,
    mask_value: float,
    q_seq_len: int,
    block_q: int,
    block_k: int,
):
  """Backward kernel computing dK and dV for one major KV block.

  Grid axis 2 selects the major KV block and grid axis 3 walks the major Q
  blocks. Partial dK / dV results are accumulated in the scratch refs across
  Q steps: they are zeroed on the first Q step and written to the output
  tiles on the last one.

  Args:
    q_tile_ref, k_tile_ref, v_tile_ref: query / key / value tiles.
    ab_tile_ref: attention-bias tile, or None.
    q_segment_ids_tile_ref, kv_segment_ids_tile_ref: segment-id tiles, or None.
    l_tile_ref, m_tile_ref: softmax row sums / row maxes from the forward pass.
    do_tile_ref: cotangent of the attention output.
    di_tile_ref: row-wise sum of `o * do`.
    dk_tile_ref, dv_tile_ref: output tiles.
    dk_scratch_ref, dv_scratch_ref: accumulators, [block_k_major, head_dim].
    sm_scale: multiplier applied to the logits in the forward pass.
    causal: whether a causal mask was applied.
    mask_value: value added to masked-out logits.
    q_seq_len: total Q sequence length.
    block_q: minor Q block size iterated inside each major Q block.
    block_k: minor KV block size iterated inside each major KV block.

  Raises:
    NotImplementedError: if segment ids are given and `block_k` is not a
      multiple of NUM_LANES.
  """
  _, _, block_q_major, _ = q_tile_ref.shape
  _, _, block_k_major, _ = k_tile_ref.shape

  q_seq_index = pl.program_id(axis=3)
  kv_seq_index = pl.program_id(axis=2)

  # First Q step for this KV block: reset the accumulators.
  @pl.when(q_seq_index == 0)
  def start_new_sequence():
    dk_scratch_ref[:, :] = jnp.zeros(dk_scratch_ref.shape, dk_scratch_ref.dtype)
    dv_scratch_ref[:, :] = jnp.zeros(dv_scratch_ref.shape, dv_scratch_ref.dtype)

  def q_body(j, _):
    start_q = j * block_q

    def k_body(i, _):
      start_k = i * block_k
      k = pl.load(k_tile_ref, (0, 0, pl.ds(start_k, block_k), slice(None)))
      v = pl.load(v_tile_ref, (0, 0, pl.ds(start_k, block_k), slice(None)))
      q = pl.load(
          q_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None))
      )  # [block_q, head_dim]
      l = pl.load(
          l_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None))
      )  # [block_q, 128]
      m = pl.load(
          m_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None))
      )  # [block_q, 128]
      do = pl.load(
          do_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None))
      )  # [block_q, head_dim]
      di = pl.load(
          di_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None))
      ).astype(jnp.float32)  # [block_q, 128]

      # Recompute the forward logits for this (q, k) tile.
      capped_logits = lax.dot_general(
          q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32
      )  # [block_q, block_k]

      if ab_tile_ref is not None:
        ab = pl.load(
            ab_tile_ref,
            (
                0,
                0,
                pl.dslice(start_q, block_q),
                pl.dslice(start_k, block_k),
            ),
        ).astype(jnp.float32)
        capped_logits += ab

      if sm_scale != 1.0:
        capped_logits *= sm_scale

      mask = None
      if q_segment_ids_tile_ref is not None:
        repeats, rem = divmod(block_k, NUM_LANES)
        if rem:
          raise NotImplementedError(
              f"kv block size must be a multiple of {NUM_LANES}"
          )
        q_segment_ids = pl.load(
            q_segment_ids_tile_ref, (0, pl.ds(start_q, block_q), slice(None))
        )  # [block_q, NUM_LANES].
        q_segment_ids = pltpu.repeat(
            q_segment_ids, repeats, axis=1
        )  # [block_q, block_k].
        kv_segment_ids = pl.load(
            kv_segment_ids_tile_ref, (slice(None), 0, pl.ds(start_k, block_k))
        )  # [1, block_k].
        mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_)

      if causal:
        # Global row / column positions of every element of this tile.
        mask_shape = (block_q, block_k)
        row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
        row_ids += q_seq_index * block_q_major + start_q
        col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
        col_ids += kv_seq_index * block_k_major + start_k
        causal_mask = col_ids <= row_ids
        mask = (
            causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
        )

      capped_logits = (
          capped_logits
          if mask is None
          else capped_logits + jnp.where(mask, 0.0, mask_value)
      )

      # Rebuild the softmax probabilities from the saved row max / row sum.
      p = jnp.exp(
          capped_logits - pltpu.repeat(m, block_k // MIN_BLOCK_SIZE, axis=1)
      )
      p = p * pltpu.repeat(
          1 / l, block_k // MIN_BLOCK_SIZE, axis=1
      )  # [block_q, block_k]

      # dV += P^T @ dO
      dv = lax.dot(p.T.astype(do.dtype), do, preferred_element_type=jnp.float32)
      pl.store(
          dv_scratch_ref,
          (pl.ds(start_k, block_k), slice(None)),
          pl.load(dv_scratch_ref, (pl.ds(start_k, block_k), slice(None)))
          + dv.astype(dv_scratch_ref.dtype),
      )

      # di: [block_q, 128]
      # do: [block_q, head_dim]
      # v: [block_k, head_dim]
      dp = lax.dot_general(
          do, v, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32
      )
      ds = (dp - pltpu.repeat(di, block_k // MIN_BLOCK_SIZE, axis=1)) * p

      if sm_scale != 1.0:
        ds = ds * sm_scale

      # ds: [block_q, block_k]
      # q: [block_q, head_dim]
      # dK += dS^T @ Q
      dk = lax.dot(ds.T.astype(do.dtype), q, preferred_element_type=jnp.float32)
      pl.store(
          dk_scratch_ref,
          (pl.ds(start_k, block_k), slice(None)),
          pl.load(dk_scratch_ref, (pl.ds(start_k, block_k), slice(None)))
          + dk.astype(dk_scratch_ref.dtype),
      )

    lax.fori_loop(0, block_k_major // block_k, k_body, None, unroll=True)

  # Under a causal mask, blocks entirely above the diagonal contribute nothing.
  if causal:
    should_run = below_or_on_diag(
        q_seq_index, block_q_major, kv_seq_index, block_k_major
    )
  else:
    should_run = True

  @pl.when(should_run)
  def run():
    lax.fori_loop(0, block_q_major // block_q, q_body, None, unroll=True)

  # Last Q step: flush the accumulators to the outputs. Cast to the output
  # tile's dtype explicitly rather than passing the ref object to astype.
  @pl.when(q_seq_index == q_seq_len // block_q_major - 1)
  def end_of_q_sequence():
    dv_tile_ref[0, 0, :, :] = dv_scratch_ref[...].astype(dv_tile_ref.dtype)
    dk_tile_ref[0, 0, :, :] = dk_scratch_ref[...].astype(dk_tile_ref.dtype)
def _flash_attention_bwd_dkv(
q,
k,
v,
ab,
segment_ids,
l,
m,
do,
di,
*,
block_q_major: int | None,
block_q: int | None,
block_k_major: int | None,
block_k: int | None,
sm_scale: float,
causal: bool = False,
mask_value: float = DEFAULT_MASK_VALUE,
debug: bool = False,
):
batch_size, num_heads, q_seq_len, head_dim = q.shape
_, _, kv_seq_len, _ = k.shape
_verify_block("block_q_major_dkv", "q_seq_len", block_q_major, q_seq_len)
_verify_block("block_q_dkv", "q_seq_len", block_q, q_seq_len)
_verify_block("block_k_major_dkv", "kv_seq_len", block_k_major, kv_seq_len)
_verify_block("block_k_dkv", "kv_seq_len", block_k, kv_seq_len)
# Broadcast out scalar values
m = jnp.broadcast_to(m[..., None], (*m.shape, MIN_BLOCK_SIZE))
l = jnp.broadcast_to(l[..., None], (*l.shape, MIN_BLOCK_SIZE))
# Preprocess contraction for bwd pass
di = jnp.broadcast_to(di[..., None], (*di.shape, MIN_BLOCK_SIZE))
# kv index needs to be before q index since q index is the contractng
# dimension.
grid = (
batch_size,
num_heads,
kv_seq_len // block_k_major,
q_seq_len // block_q_major,
)
def qo_index_map(batch_index, head_index, kv_seq_index, q_seq_index):
if causal:
# If the q block is skipped, stay at the 0th q block.
next_q_index = lax.select(
below_or_on_diag(
q_seq_index, block_q_major, kv_seq_index, block_k_major
),
q_seq_index,
0,
)
else:
next_q_index = q_seq_index
return (batch_index, head_index, next_q_index, 0)
qo_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, head_dim))
assert qo_spec.block_shape is not None
assert q.ndim == len(qo_spec.block_shape)
do_spec = qo_spec
assert do.ndim == len(qo_spec.block_shape)
def kv_index_map(batch_index, head_index, kv_seq_index, _):
return (batch_index, head_index, kv_seq_index, 0)
kv_spec = pl.BlockSpec(kv_index_map, (1, 1, block_k_major, head_dim))
assert kv_spec.block_shape is not None
assert k.ndim == len(kv_spec.block_shape)
assert v.ndim == len(kv_spec.block_shape)
def lm_index_map(batch_index, head_index, _, q_seq_index):
return (batch_index, head_index, q_seq_index, 0)
lm_spec = pl.BlockSpec(lm_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
assert lm_spec.block_shape is not None
assert l.ndim == len(lm_spec.block_shape)
assert m.ndim == len(lm_spec.block_shape)
di_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
assert di_spec.block_shape is not None
assert di.ndim == len(di_spec.block_shape)
def ab_index_map(batch_index, head_index, kv_seq_index, q_seq_index):