-
Notifications
You must be signed in to change notification settings - Fork 357
/
Copy pathattention.py
3676 lines (3299 loc) · 167 KB
/
attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Attention."""
import collections
from contextlib import nullcontext
import functools
from importlib.metadata import version
import math
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings
import numpy as np
from pkg_resources import packaging
import torch
import torch.nn.functional as F
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_fwd_qkvpacked,
fused_attn_bwd_qkvpacked,
fused_attn_fwd_kvpacked,
fused_attn_bwd_kvpacked,
fused_attn_fwd,
fused_attn_bwd,
QKVLayout,
AttnBiasType,
AttnMaskType,
FusedAttnBackend,
)
from transformer_engine.pytorch.module import LayerNormLinear, Linear
from transformer_engine.pytorch.utils import (
divide,
attention_mask_func,
split_tensor_along_dim,
get_device_compute_capability,
get_default_init_method,
)
from transformer_engine.pytorch.constants import (
AttnMaskTypes,
AttnTypes,
AttnBiasTypes,
QKVLayouts,
dist_group_type,
TE_DType,
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
get_distributed_world_size,
get_distributed_rank,
checkpoint,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
# Installed flash-attn version plus feature gates used to decide which optional
# keyword arguments (e.g. window_size, alibi_slopes, deterministic) are passed to
# flash-attn later in this file.
# NOTE(review): importlib.metadata.version("flash-attn") raises PackageNotFoundError
# when flash-attn is not installed, so importing this module requires it.
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("2.0.6")
_flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= packaging.version.Version("2.4.1")
# The flash-attn entry points are only imported for versions >= 2.0.6.
# NOTE(review): below that version these names are never bound, so any later use
# (e.g. _flash_attn_forward) would raise NameError; presumably the version is
# checked before these paths are taken — confirm.
if _flash_attn_version >= _flash_attn_version_required:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
    from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd # pylint: disable=no-name-in-module
    from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward # pylint: disable=no-name-in-module,ungrouped-imports
    from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward # pylint: disable=no-name-in-module
# Integer debug level read from the NVTE_DEBUG environment variable (default 0).
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# Module-level memo used by get_alibi(): the last computed slopes/bias together with
# the parameters they were built for. get_alibi() only recomputes slopes/bias when the
# corresponding *_require_update flag is True, and clears the flag afterwards.
# NOTE(review): the flags are presumably set to True by callers elsewhere in this file
# when num_heads / sequence lengths change — confirm.
_alibi_cache = {
    "_num_heads": None,
    "_alibi_slopes": None,
    "_max_seqlen_q": None,
    "_max_seqlen_kv": None,
    "_alibi_bias": None,
    "_alibi_slopes_require_update": False,
    "_alibi_bias_require_update": False,
}
# Public API of this module.
__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]
class InferenceParams: # pylint: disable=too-few-public-methods
    """
    Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference.

    Parameters
    ----------
    max_batch_size : int
                    maximum batch size during inference.
    max_sequence_length : int
                         maximum sequence length during inference.
    """

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_batch_size = max_batch_size
        self.max_sequence_length = max_sequence_length
        # Both offsets start at zero; this class never advances them itself.
        self.batch_size_offset = 0
        self.sequence_len_offset = 0
        # layer_number -> (key_memory, value_memory); populated outside this class.
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_indices):
        """
        Reorders the KV cache using the specified batch indices.

        Every cached (key, value) pair is re-indexed along dimension 1, which is
        treated as the batch dimension of the cache.

        Parameters
        ----------
        batch_indices : List[int]
                       Sequence of indices to reorder along the batch dimensions of
                       the KV cache. Must have a length equal to the batch size.

        Raises
        ------
        ValueError
            If the KV cache is empty.
        """
        if not self.key_value_memory_dict:
            raise ValueError("should not swap when dict in empty")
        for layer_number in list(self.key_value_memory_dict):
            key_memory, value_memory = self.key_value_memory_dict[layer_number]
            # The reorder must cover the whole batch dimension of the cache.
            assert len(batch_indices) == key_memory.shape[1]
            self.key_value_memory_dict[layer_number] = (
                key_memory[:, batch_indices],
                value_memory[:, batch_indices],
            )
@torch.no_grad()
def get_alibi(
    num_heads: int,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    alibi_slopes: Optional[torch.Tensor] = None,
    bias_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Return the ALiBi slopes and bias, rebuilding them only when flagged.

    Results are memoized in the module-level ``_alibi_cache``. The slopes are
    rebuilt only when ``_alibi_cache["_alibi_slopes_require_update"]`` is True,
    and the bias only when ``_alibi_cache["_alibi_bias_require_update"]`` is
    True; each flag is reset to False after its rebuild. When a flag is not set
    the cached tensor is returned unchanged, regardless of the arguments.

    Parameters
    ----------
    num_heads: int
              Number of heads.
    max_seqlen_q: int
              Maximum sequence length for queries.
    max_seqlen_kv: int
              Maximum sequence length for keys and values.
    alibi_slopes: Optional[torch.Tensor], default = `None`
              Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
    bias_dtype: Optional[torch.dtype], default = `None`
              Dtype of the generated ALiBi bias. If None, use torch.float32.

    Returns
    ----------
    alibi_slopes: torch.Tensor
              ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
    alibi_bias: torch.Tensor
              ALiBi bias in FP32 or `bias_dtype`. If `alibi_slopes` is in [num_heads] shape,
              then `alibi_bias` is in [1, num_heads, max_seqlen_q, max_seqlen_kv], and if
              `alibi_slopes` is in [batch_size, num_heads], then the bias is in
              [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].

    Raises
    ------
    ValueError
              If the cached slopes are neither 1-D nor 2-D when the bias is rebuilt.
    """
    global _alibi_cache
    if _alibi_cache["_alibi_slopes_require_update"]:
        if alibi_slopes is not None:
            _alibi_cache["_alibi_slopes"] = alibi_slopes
        else:
            # Default slopes: m_k = 2^(-8k/n) for k = 1..n, where n is the largest
            # power of two <= num_heads. Remaining heads get 2^(-4j/n) for
            # j = 1, 3, 5, ... (the odd terms of the sequence built for 2n heads).
            n = 2 ** math.floor(math.log2(num_heads))
            m_0 = 2.0 ** (-8.0 / n)
            m = torch.pow(m_0, torch.arange(1, 1 + n))
            if n < num_heads:
                m_hat_0 = 2.0 ** (-4.0 / n)
                m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
                m = torch.cat([m, m_hat])
            _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda")
        _alibi_cache["_num_heads"] = num_heads
        _alibi_cache["_alibi_slopes_require_update"] = False

    if _alibi_cache["_alibi_bias_require_update"]:
        slopes = _alibi_cache["_alibi_slopes"]
        assert slopes is not None, "ALiBi slopes can not be None!"
        if slopes.dim() == 1:
            # [num_heads] -> broadcastable [1, num_heads, 1, 1]
            slopes_shape = torch.Size([1, slopes.shape[0], 1, 1])
        elif slopes.dim() == 2:
            # [batch_size, num_heads] -> broadcastable [batch_size, num_heads, 1, 1]
            slopes_shape = torch.Size([*slopes.shape, 1, 1])
        else:
            # Previously this case fell through and failed with UnboundLocalError.
            raise ValueError(
                f"ALiBi slopes must be 1-D or 2-D, got shape {tuple(slopes.shape)}!"
            )
        # bias[i, j] = -|(j - (max_seqlen_kv - 1)) - (i - (max_seqlen_q - 1))|:
        # relative distance with the last query aligned to the last key.
        bias = torch.arange(
            1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv)
        bias = bias - torch.arange(
            1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view(1, 1, max_seqlen_q, 1)
        bias = bias.abs().mul(-1)
        bias = bias * slopes.view(slopes_shape)
        _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
        bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
        _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
        _alibi_cache["_alibi_bias_require_update"] = False

    return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"]
def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch.

    Each sample's length is the sum of its mask entries along the sequence axis
    (the number of True / nonzero positions). The result is created on
    ``mask.device``.
    """
    mask = mask.squeeze(1).squeeze(1)
    seqlens = mask.sum(dim=1)
    cu_seqlens = seqlens.cumsum(dim=0).to(torch.int32)
    # Prepend the leading 0 on the same device as the sums. A hard-coded "cuda"
    # device breaks torch.cat for CPU masks and for masks on a non-current GPU.
    zero = torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device)
    return torch.cat((zero, cu_seqlens))
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch, and an int64 tensor of shape [batch_size * max_seqlen, 1, 1]
    containing the flat indices of the valid (True / nonzero) mask positions.

    The indices are padded at the end with the out-of-range value
    ``batch_size * max_seqlen``. Both tensors are created on ``mask.device``.
    """
    mask = mask.squeeze(1).squeeze(1)
    bs, seqlen = mask.shape

    seqlens = mask.sum(dim=1)
    cu_seqlens = seqlens.cumsum(dim=0).to(torch.int32)
    # Leading 0 lives on the mask's device (previously hard-coded to "cuda").
    zero = torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device)
    cu_seqlens = torch.cat((zero, cu_seqlens))

    # nonzero() on the flattened mask yields [num_valid, 1] int64 flat indices,
    # in ascending (sample-major) order; add a trailing unit dim -> [num_valid, 1, 1].
    indices = mask.reshape(-1).nonzero().unsqueeze(-1)
    pad_amount = bs * seqlen - indices.shape[0]
    # Pad dim 0 up to bs * seqlen with the value bs * seqlen (one past the last flat index).
    indices = F.pad(input=indices, pad=(0, 0, 0, 0, 0, pad_amount),
                    mode="constant", value=float(bs * seqlen))
    return cu_seqlens, indices
def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """
    Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int64
    tensor of shape [batch_size * max_seqlen, 1, 1] containing the flat indices
    ``sample * max_seqlen + position`` of the valid tokens in a batch.

    The indices are in ascending order and padded at the end with the
    out-of-range value ``batch_size * max_seqlen``. The result is created on
    ``cu_seqlens.device``.
    """
    bs = len(cu_seqlens) - 1
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    # [bs, max_seqlen] grid marking which positions hold a real token. This is
    # vectorized instead of a Python loop over tensor elements, and stays integral.
    # torch.Tensor(list) used to build float32, which rounds indices above 2**24.
    positions = torch.arange(max_seqlen, device=cu_seqlens.device)
    valid = positions.unsqueeze(0) < seqlens.unsqueeze(1)
    # Row-major nonzero() yields sample-major, position-minor flat indices (int64).
    indices = valid.reshape(-1).nonzero().unsqueeze(-1)
    pad_amount = bs * max_seqlen - indices.shape[0]
    indices = F.pad(input=indices, pad=(0, 0, 0, 0, 0, pad_amount),
                    mode="constant", value=float(bs * max_seqlen))
    return indices
@functools.lru_cache
def _get_full_cu_seqlens(
batch_size: int,
max_seqlen: int,
device: torch.device,
) -> torch.Tensor:
"""Cumulative sequence lengths in full data batch
All sequences in batch have the maximum sequence length.
"""
return torch.arange(
0,
(batch_size + 1) * max_seqlen,
step=max_seqlen,
dtype=torch.int32,
device=device,
)
@jit_fuser
def pack_tensor(
    indices: torch.Tensor,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Gather the rows of `tensor` selected by `indices` along dim 0.

    A single zero row is appended to `tensor` first, so an index equal to
    ``tensor.shape[0]`` gathers zeros. `indices` is presumably shaped [N, 1, 1]
    (as produced by the index helpers in this file); it is repeated across
    dims 1 and 2 before the gather.
    """
    dim1 = tensor.shape[1]
    dim2 = tensor.shape[2]
    zero_row = torch.zeros(1, dim1, dim2, dtype=tensor.dtype, device=tensor.device)
    padded = torch.cat((tensor, zero_row), dim=0)
    gather_index = indices.repeat(1, dim1, dim2)
    return torch.gather(padded, 0, gather_index)
@jit_fuser
def pack_2_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pack two tensors with the same `indices` by applying `pack_tensor` to each.
    """
    return pack_tensor(indices, t1), pack_tensor(indices, t2)
@jit_fuser
def pack_3_tensors(
    indices: torch.Tensor,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Pack three tensors with the same `indices` by applying `pack_tensor` to each.
    """
    return (
        pack_tensor(indices, t1),
        pack_tensor(indices, t2),
        pack_tensor(indices, t3),
    )
@jit_fuser
def unpack_tensor(
    indices: torch.Tensor,
    dim0: int,
    tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Inverse of `pack_tensor`: scatter the rows of `tensor` back to positions
    `indices` in a zero-filled tensor with `dim0` rows.

    An extra trailing row is allocated so that padding indices equal to `dim0`
    have somewhere to write; that row is dropped before returning.
    """
    dim1 = tensor.shape[1]
    dim2 = tensor.shape[2]
    scatter_index = indices.repeat(1, dim1, dim2)
    buffer = torch.zeros(dim0 + 1, dim1, dim2, dtype=tensor.dtype, device=tensor.device)
    buffer.scatter_(0, scatter_index, tensor)
    # Drop the trailing row that absorbed writes from padding indices.
    return buffer[:-1]
@jit_fuser
def unpack_2_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_2_tensors`: applies `unpack_tensor` to each input.
    """
    return unpack_tensor(indices, dim0, t1), unpack_tensor(indices, dim0, t2)
@jit_fuser
def unpack_3_tensors(
    indices: torch.Tensor,
    dim0: int,
    t1: torch.Tensor,
    t2: torch.Tensor,
    t3: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Inverse of `pack_3_tensors`: applies `unpack_tensor` to each input.
    """
    return (
        unpack_tensor(indices, dim0, t1),
        unpack_tensor(indices, dim0, t2),
        unpack_tensor(indices, dim0, t3),
    )
class PackTensors(torch.autograd.Function):
    """
    Autograd function that packs one to three tensors with shared `indices`.

    Forward gathers rows with the `pack_*` helpers. Backward scatters the
    incoming gradients back to the original leading size with the matching
    `unpack_*` helpers. `indices` receives no gradient.
    """

    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        *tensors: Tuple[torch.Tensor, ...]
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported."
        ctx.indices = indices
        # Leading size needed by backward to rebuild the unpacked gradients.
        ctx.dim0 = tensors[0].shape[0]
        packers = {1: pack_tensor, 2: pack_2_tensors}
        packer = packers.get(len(tensors), pack_3_tensors)
        return packer(indices, *tensors)

    @staticmethod
    def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]):
        count = len(grad_outputs)
        if count == 1:
            grads = (unpack_tensor(ctx.indices, ctx.dim0, *grad_outputs),)
        elif count == 2:
            grads = unpack_2_tensors(ctx.indices, ctx.dim0, *grad_outputs)
        else:
            grads = unpack_3_tensors(ctx.indices, ctx.dim0, *grad_outputs)
        # First slot corresponds to `indices`, which gets no gradient.
        return (None, *grads)
class UnpackTensor(torch.autograd.Function):
    """
    Autograd function that scatters a packed tensor back to `dim0` rows.

    Backward re-packs the incoming gradient with the same `indices`. Neither
    `indices` nor `dim0` receives a gradient.
    """

    @staticmethod
    def forward(
        ctx,
        indices: torch.Tensor,
        dim0: int,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        ctx.indices = indices
        unpacked = unpack_tensor(indices, dim0, tensor)
        return unpacked

    @staticmethod
    def backward(ctx, grad_output):
        packed_grad = pack_tensor(ctx.indices, grad_output)
        return None, None, packed_grad
def flash_attn_p2p_communicate(rank, send_tensor, send_dst,
                               recv_tensor, recv_src,
                               cp_group, batch_p2p_comm):
    """
    Point-to-point communications of KV and dKV in Attention with context parallelism.

    Posts one send of `send_tensor` to rank `send_dst` and one receive into
    `recv_tensor` from rank `recv_src` on `cp_group`, and returns the pending
    requests for the caller to wait on. Even ranks issue the send first and
    odd ranks issue the receive first. When `batch_p2p_comm` is truthy, both
    operations go through ``torch.distributed.batch_isend_irecv``. Otherwise
    plain ``isend`` / ``irecv`` calls are used.
    """
    dist = torch.distributed
    send_first = rank % 2 == 0

    if batch_p2p_comm:
        send_op = dist.P2POp(dist.isend, send_tensor, send_dst, cp_group)
        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_src, cp_group)
        ordered_ops = [send_op, recv_op] if send_first else [recv_op, send_op]
        return dist.batch_isend_irecv(ordered_ops)

    # Non-batched isend/irecv start immediately, so call order per parity matters.
    if send_first:
        send_req = dist.isend(send_tensor, send_dst, cp_group)
        recv_req = dist.irecv(recv_tensor, recv_src, cp_group)
        return [send_req, recv_req]
    recv_req = dist.irecv(recv_tensor, recv_src, cp_group)
    send_req = dist.isend(send_tensor, send_dst, cp_group)
    return [recv_req, send_req]
@jit_fuser
def flash_attn_fwd_out_correction(out, out_per_step, softmax_lse, softmax_lse_per_step):
    """
    Merge partial outputs of each step in Attention with context parallelism.

    Adds ``out_per_step * exp(softmax_lse_per_step - softmax_lse)`` into `out`
    in place. The LSE tensors are presumably [b, np, sq] (as in AttnFuncWithCP).
    The scale is transposed to [b, sq, np] and given a trailing unit dim so it
    broadcasts against `out_per_step`.
    """
    scale = torch.exp(softmax_lse_per_step - softmax_lse)
    scale = scale.transpose(1, 2).unsqueeze(-1)
    out.add_(out_per_step * scale)
@jit_fuser
def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
    """
    Merge softmax stats of each step in Attention with context parallelism.

    Updates `softmax_lse` in place to
    ``log(exp(softmax_lse) + exp(softmax_lse_per_step))``; the per-step stats
    are promoted to float64 as before. The value is computed with
    ``torch.logaddexp``, which evaluates ``max + log1p(exp(-|a - b|))``. The
    previous exp/add/log sequence overflowed to inf for LSE values above ~709
    and underflowed to -inf below ~-745, even in float64.
    """
    merged = torch.logaddexp(softmax_lse, softmax_lse_per_step.to(torch.double))
    # copy_ keeps the update in place, so views of the caller's tensor (e.g. a
    # slice of the full LSE) are updated too.
    softmax_lse.copy_(merged)
class AttnFuncWithCP(torch.autograd.Function):
"""
Attention implementation with context parallelism.
Split attention compute into multiple steps, and overlap current-step
compute with next-step communication.
"""
@staticmethod
def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, attn_mask_type,
deterministic, use_fused_attention):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
cp_size = get_distributed_world_size(cp_group)
rank = get_distributed_rank(cp_group)
send_dst = cp_global_ranks[(rank + 1) % cp_size]
recv_src = cp_global_ranks[(rank + cp_size - 1) % cp_size]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
causal = (attn_mask_type == "causal")
if causal:
# [b, s, np, hn] -> [b, 2, s//2, np, hn]
q, k, v = [x.view(x.shape[0], 2, x.shape[1]//2, *x.shape[2:]) for x in [q, k, v]]
assert(q.shape[-1] % 8 == 0), "hidden size per attention head should be multiple of 8"
fa_optional_forward_kwargs = {}
if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = [-1, 0] if causal else [-1, -1]
if _flash_attn_2_4_plus:
fa_optional_forward_kwargs["alibi_slopes"] = None
# Flash Attn inputs
q_inputs = [None, None]
kv_inputs = [None, None]
# Flash Attn outputs
out_per_step = [None for _ in range(cp_size)]
softmax_lse_per_step = [None for _ in range(cp_size)]
rng_states = [None for _ in range(cp_size)]
# create two streams to resolve wave quantization issue of Flash Attn in each step
flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
# synchronize fwd results correction across steps
fwd_results_correction_done = torch.cuda.Event()
p2p_comm_buffers = [None for _ in range(cp_size)]
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
send_recv_reqs = [[], []]
for i in range(cp_size+1):
if i < cp_size:
with torch.cuda.stream(flash_attn_streams[i%2]):
# wait until KV is received
for req in send_recv_reqs[(i+1)%2]:
req.wait()
if i < (cp_size-1):
p2p_comm_buffers[i+1] = torch.empty_like(p2p_comm_buffers[i])
send_recv_reqs[i%2] = flash_attn_p2p_communicate(rank,
p2p_comm_buffers[i],
send_dst,
p2p_comm_buffers[i+1],
recv_src,
cp_group,
batch_p2p_comm)
kv_inputs[i%2] = p2p_comm_buffers[i]
if causal:
if i == 0:
if use_fused_attention:
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(
2, k.shape[0], -1, *k.shape[-2:])
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,
cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
kv_inputs[i%2][1], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, dropout=dropout_p,
qkv_layout="bshd_bshd_bshd", attn_mask_type="causal",
)
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal=True, return_softmax=False,
**fa_optional_forward_kwargs
)
elif i <= rank:
if use_fused_attention:
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_k//2, cu_seqlens_q,
cu_seqlens_k//2, q_inputs[i%2], kv_inputs[i%2][0],
kv_inputs[i%2][1], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, dropout=dropout_p,
qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask",
)
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
# [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = [-1, -1]
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q, cu_seqlens_k//2, max_seqlen_q, max_seqlen_k//2,
dropout_p, softmax_scale, causal=False, return_softmax=False,
**fa_optional_forward_kwargs
)
else:
if use_fused_attention:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_inputs[i%2] = q[:, 1, ...].contiguous()
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(
2, k.shape[0], -1, *k.shape[-2:])
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
fused_attn_fwd(
is_training, max_seqlen_q//2, max_seqlen_k, cu_seqlens_q//2,
cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
kv_inputs[i%2][1], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, dropout=dropout_p,
qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask",
)
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_forward_kwargs["window_size"] = [-1, -1]
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q//2, cu_seqlens_k, max_seqlen_q//2, max_seqlen_k,
dropout_p, softmax_scale, causal=False, return_softmax=False,
**fa_optional_forward_kwargs
)
else:
if use_fused_attention:
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \
fused_attn_fwd(
is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,
cu_seqlens_k, q, kv_inputs[i%2][0],
kv_inputs[i%2][1], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, dropout=dropout_p,
qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask",
)
else:
# [b, sq, np, hn] -> [b*sq, np, hn]
q_inputs[i%2] = q.view(-1, *q.shape[-2:])
# [2, b, sk, np, hn] -> [2, b*sk, np, hn]
kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
_, _, _, _, out_per_step[i], \
softmax_lse_per_step[i], _, rng_states[i] = _flash_attn_forward(
q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1],
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal=False, return_softmax=False,
**fa_optional_forward_kwargs
)
if i > 0:
# wait until fwd restuls correction of last step is done
if i > 1:
flash_attn_streams[(i-1)%2].wait_event(fwd_results_correction_done)
if use_fused_attention:
# [b, np, sq, 1] -> [b, np, sq]
softmax_lse_per_step[i-1].squeeze_(-1)
with torch.cuda.stream(flash_attn_streams[(i-1)%2]):
if i == 1:
out = torch.empty_like(q).zero_()
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
if causal:
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view(
*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2
)
elif (i-1) <= rank or not causal:
flash_attn_fwd_softmax_lse_correction(softmax_lse,
softmax_lse_per_step[i-1])
else:
flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :],
softmax_lse_per_step[i-1])
if i < cp_size:
flash_attn_streams[(i-1)%2].record_event(fwd_results_correction_done)
torch.cuda.current_stream().wait_stream(flash_attn_streams[1])
softmax_lse = softmax_lse.to(torch.float)
for i in range(cp_size):
# [b*sq, np, hn] -> [b, sq, np, hn] or [b*sq//2, np, hn] -> [b, sq//2, np, hn]
out_ = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:])
if i <= rank or not causal:
flash_attn_fwd_out_correction(out.view(*out_.shape),
out_,
softmax_lse,
softmax_lse_per_step[i])
else:
flash_attn_fwd_out_correction(out[:, 1, ...],
out_,
softmax_lse_[..., 1, :],
softmax_lse_per_step[i])
kv = p2p_comm_buffers[-1]
if use_fused_attention:
out = out.view(out.shape[0], -1, *out.shape[-2:])
else:
out = out.view(-1, *out.shape[-2:])
ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k)
ctx.rng_states = rng_states
ctx.cp_group = cp_group
ctx.cp_global_ranks = cp_global_ranks
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.deterministic = deterministic
ctx.use_fused_attention = use_fused_attention
return out
@staticmethod
def backward(ctx, dout):
q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
cp_size = get_distributed_world_size(ctx.cp_group)
rank = get_distributed_rank(ctx.cp_group)
send_dst = ctx.cp_global_ranks[(rank + cp_size - 1) % cp_size]
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
if ctx.causal:
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
if ctx.use_fused_attention:
# [b, np, sq//2] -> [b, np, sq//2, 1]
softmax_lse_.unsqueeze_(-1)
if ctx.use_fused_attention:
# [b, np, sq] -> [b, np, sq, 1]
softmax_lse.unsqueeze_(-1)
out = out.view(*q.shape)
dout = dout.view(*q.shape)
# Flash Attn outputs
dq = torch.empty_like(q)
p2p_comm_buffers = [torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), \
torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device)]
p2p_comm_buffers[0][0].copy_(kv)
send_recv_reqs = []
fa_optional_backward_kwargs = {}
if _flash_attn_2_4_plus:
fa_optional_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus:
fa_optional_backward_kwargs["deterministic"] = ctx.deterministic
for i in range(cp_size):
# wait until KV is received
for req in send_recv_reqs:
req.wait()
send_tensor = p2p_comm_buffers[i%2]
recv_tensor = p2p_comm_buffers[(i+1)%2]
if i == 0:
send_tensor = send_tensor[0]
recv_tensor = recv_tensor[0]
if i == (cp_size-1):
send_tensor = send_tensor[1]
recv_tensor = recv_tensor[1]
send_recv_reqs = flash_attn_p2p_communicate(rank,
send_tensor,
send_dst,
recv_tensor,
recv_src,
ctx.cp_group,
batch_p2p_comm)
kv = p2p_comm_buffers[i%2][0]
# In reversed order of fwd
if ctx.causal:
if i == (cp_size-1):
if ctx.use_fused_attention:
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_ = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
out_ = out.view(out.shape[0], -1, *out.shape[-2:])
dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k,
cu_seqlens_q, cu_seqlens_k,
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype],
[softmax_lse, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout="bshd_bshd_bshd",
attn_mask_type="causal",
)
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = [-1, 0]
_flash_attn_backward(
dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k,
ctx.dropout_p, ctx.softmax_scale, True,
rng_state=ctx.rng_states[cp_size-i-1],
**fa_optional_backward_kwargs
)
elif i >= (cp_size-rank-1):
if ctx.use_fused_attention:
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_ = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_ = kv[:, :, 0, ...].contiguous()
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
out_ = out.view(out.shape[0], -1, *out.shape[-2:])
dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:])
dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k//2,
cu_seqlens_q, cu_seqlens_k//2,
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype],
[softmax_lse, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout="bshd_bshd_bshd",
attn_mask_type="no_mask",
)
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = [-1, -1]
_flash_attn_backward(
dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k//2,
ctx.max_seqlen_q, ctx.max_seqlen_k//2,
ctx.dropout_p, ctx.softmax_scale, False,
rng_state=ctx.rng_states[cp_size-i-1],
**fa_optional_backward_kwargs
)
else:
if ctx.use_fused_attention:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_ = q[:, 1, ...].contiguous()
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:])
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
out_ = out[:, 1, ...].contiguous()
dout_ = dout[:, 1, ...].contiguous()
dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q//2, ctx.max_seqlen_k,
cu_seqlens_q//2, cu_seqlens_k,
q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype],
[softmax_lse_, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout="bshd_bshd_bshd",
attn_mask_type="no_mask",
)
else:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:])
dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = [-1, -1]
_flash_attn_backward(
dout_, q_, kv_[0], kv_[1], out_, softmax_lse_,
dq_, dkv_[0], dkv_[1], cu_seqlens_q//2, cu_seqlens_k,
ctx.max_seqlen_q//2, ctx.max_seqlen_k,
ctx.dropout_p, ctx.softmax_scale, False,
rng_state=ctx.rng_states[cp_size-i-1],
**fa_optional_backward_kwargs
)
else:
if ctx.use_fused_attention:
dq_, dk_, dv_, _ = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_k,
cu_seqlens_q, cu_seqlens_k,
q, kv[0], kv[1], out, dout, TE_DType[q.dtype],
[softmax_lse, ctx.rng_states[cp_size-i-1]],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout="bshd_bshd_bshd",
attn_mask_type="no_mask",
)
else:
# [b, sq, np, hn] -> [b*sq, np, hn]
q_ = q.view(-1, *q.shape[-2:])
dq_ = torch.empty_like(q_)
# [2, b, sk, np, hn] -> [2, b*sk, np, hn]
kv_ = kv.view(2, -1, *kv.shape[-2:])
dkv_ = torch.empty_like(kv_)
# [b, sq, np, hn] -> [b*sq, np, hn]
out_ = out.view(-1, *out.shape[-2:])
dout_ = dout.view(-1, *dout.shape[-2:])
if _flash_attn_2_3_plus:
fa_optional_backward_kwargs["window_size"] = [-1, -1]
_flash_attn_backward(
dout_, q_, kv_[0], kv_[1], out_, softmax_lse,
dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k,
ctx.dropout_p, ctx.softmax_scale, False,
**fa_optional_backward_kwargs
)
if i >= (cp_size-rank-1) or not ctx.causal:
# [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal
# [b*sq, np, hn] -> [b, sq, np, hn] if not causal
dq_ = dq_.view(*dq.shape)
else:
# [b*sq//2, np, hn] -> [b, sq//2, np, hn]
dq_ = dq_.view(dq.shape[0], *dq.shape[2:])
if ctx.causal:
if i > (cp_size-rank-1):
dq.add_(dq_)
elif i == (cp_size-rank-1):
if rank == (cp_size-1):
dq.copy_(dq_)
else:
dq[:, 0, ...].copy_(dq_[:, 0, ...])
dq[:, 1, ...].add_(dq_[:, 1, ...])
elif i > 0:
dq[:, 1, ...].add_(dq_)
else:
dq[:, 1, ...].copy_(dq_)
else:
if i == 0:
dq.copy_(dq_)
else:
dq.add_(dq_)
# wait until dKV is received
for req in send_recv_reqs:
req.wait()
dkv = p2p_comm_buffers[(i+1)%2][1]
if ctx.use_fused_attention:
dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
if ctx.causal and i >= (cp_size-rank-1) and i != (cp_size-1):
# [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn]
dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:])
else:
# [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal
# [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal
dkv_ = dkv_.view(*dkv.shape)
if ctx.causal:
if i == (cp_size-1):
if rank == 0:
dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...])
dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])
else:
dkv.add_(dkv_)
elif i >= (cp_size-rank-1):
if i == 0 and rank == (cp_size-1):
dkv[:, :, 0, ...].copy_(dkv_)