Kernels
danieldk (HF Staff) committed
Commit 6c0af68 · verified · 1 Parent(s): 1731a09

Build uploaded using `kernels`.

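For reference, builds uploaded this way are meant to be fetched back through the `kernels` package rather than compiled locally. A minimal sketch of loading one of these builds from the Hub, assuming the repository id `kernels-community/flash-attn3` (the exact repo id is not stated in this commit):

```python
# Sketch only: fetch a pre-built flash-attn3 kernel with the `kernels` package.
# The repository id is an assumption; substitute the repo this commit belongs to.
import torch
from kernels import get_kernel

# Downloads and loads the variant matching the local torch/CUDA/ABI,
# e.g. build/torch29-cxx11-cu128-x86_64-linux from this commit.
flash_attn3 = get_kernel("kernels-community/flash-attn3")

q = torch.randn(1, 128, 8, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 128, 8, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 128, 8, 128, dtype=torch.bfloat16, device="cuda")

out = flash_attn3.flash_attn_func(q, k, v, causal=True)
```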
build/torch29-cxx11-cu128-x86_64-linux/{_flash_attn3_1d39a44.abi3.so → _flash_attn3_7e34105_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:8ae6550e9ab9efd6ab826e06037a9140a0f0e5e804da695dff4cb532480faae7
- size 779759280
+ oid sha256:64416ac848e1977a192ae0a8d031036b7f97812be5b7d9d7f5a55a341892c13b
+ size 804210888
build/torch29-cxx11-cu128-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _flash_attn3_1d39a44
- ops = torch.ops._flash_attn3_1d39a44
+ from . import _flash_attn3_7e34105_dirty
+ ops = torch.ops._flash_attn3_7e34105_dirty

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_flash_attn3_1d39a44::{op_name}"
+     return f"_flash_attn3_7e34105_dirty::{op_name}"
build/torch29-cxx11-cu128-x86_64-linux/flash_attn_interface.py CHANGED
@@ -1,50 +1,79 @@
1
  # Copyright (c) 2023, Tri Dao.
2
 
3
- from typing import Optional, Union
4
 
5
  import torch
6
  import torch.nn as nn
7
 
8
  from ._ops import ops as flash_attn_3_cuda
 
9
 
10
  def maybe_contiguous(x):
11
  return x.contiguous() if x is not None and x.stride(-1) != 1 else x
12
 
13
14
  def _flash_attn_forward(
15
- q,
16
- k,
17
- v,
18
- k_new,
19
- v_new,
20
- qv,
21
- out,
22
- cu_seqlens_q,
23
- cu_seqlens_k,
24
- cu_seqlens_k_new,
25
- seqused_q,
26
- seqused_k,
27
- max_seqlen_q,
28
- max_seqlen_k,
29
- page_table,
30
- kv_batch_idx,
31
- leftpad_k,
32
- rotary_cos,
33
- rotary_sin,
34
- seqlens_rotary,
35
- q_descale,
36
- k_descale,
37
- v_descale,
38
- softmax_scale,
39
- causal,
40
- window_size=(-1, -1),
41
- attention_chunk=0,
42
- softcap=0.0,
43
- rotary_interleaved=True,
44
- scheduler_metadata=None,
45
- num_splits=1,
46
- pack_gqa=None,
47
- sm_margin=0):
 
 
48
  q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
49
  v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
50
  cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
@@ -56,14 +85,14 @@ def _flash_attn_forward(
56
  ]
57
  rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
58
  seqlens_rotary = maybe_contiguous(seqlens_rotary)
59
- out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
60
  q,
61
  k,
62
  v,
63
  k_new,
64
  v_new,
65
  qv,
66
- out,
67
  cu_seqlens_q,
68
  cu_seqlens_k,
69
  cu_seqlens_k_new,
@@ -82,8 +111,8 @@ def _flash_attn_forward(
82
  v_descale,
83
  softmax_scale,
84
  causal,
85
- window_size[0],
86
- window_size[1],
87
  attention_chunk,
88
  softcap,
89
  rotary_interleaved,
@@ -92,59 +121,314 @@ def _flash_attn_forward(
92
  pack_gqa,
93
  sm_margin,
94
  )
95
- return out, softmax_lse, *rest
96
 
97
 
98
  def _flash_attn_backward(
99
  dout,
100
  q,
101
  k,
102
  v,
103
  out,
104
  softmax_lse,
105
  cu_seqlens_q,
106
  cu_seqlens_k,
107
  sequed_q,
108
  sequed_k,
109
  max_seqlen_q,
110
  max_seqlen_k,
111
- dq,
112
- dk,
113
- dv,
114
  softmax_scale,
115
- causal,
116
- window_size=(-1, -1),
117
- softcap=0.0,
118
- deterministic=False,
119
- sm_margin=0,
120
- ):
121
- # dq, dk, dv are allocated by us so they should already be contiguous
122
- dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
123
- dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
124
  dout,
125
  q,
126
  k,
127
  v,
128
  out,
129
  softmax_lse,
130
  dq,
131
  dk,
132
  dv,
133
- cu_seqlens_q,
134
- cu_seqlens_k,
135
- sequed_q,
136
- sequed_k,
137
- max_seqlen_q,
138
- max_seqlen_k,
139
- softmax_scale,
140
- causal,
141
- window_size[0],
142
- window_size[1],
143
- softcap,
144
- deterministic,
145
- sm_margin,
146
  )
147
- return dq, dk, dv, softmax_d
148
 
149
 
150
  class FlashAttnQKVPackedFunc(torch.autograd.Function):
@@ -161,6 +445,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
161
  deterministic=False,
162
  num_heads_q=None,
163
  sm_margin=0,
 
164
  ):
165
  if softmax_scale is None:
166
  softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -188,7 +473,8 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
188
  q_descale, k_descale, v_descale,
189
  softmax_scale,
190
  causal=causal,
191
- window_size=window_size,
 
192
  attention_chunk=attention_chunk,
193
  softcap=softcap,
194
  sm_margin=sm_margin,
@@ -203,8 +489,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
203
  ctx.deterministic = deterministic
204
  ctx.ndim = qkv.dim()
205
  ctx.sm_margin = sm_margin
206
- # return out, softmax_lse
207
- return out
208
 
209
  @staticmethod
210
  def backward(ctx, dout, *args):
@@ -235,13 +520,14 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
235
  dv,
236
  ctx.softmax_scale,
237
  ctx.causal,
238
- ctx.window_size,
 
239
  ctx.softcap,
240
  ctx.deterministic,
241
  ctx.sm_margin,
242
  )
243
  dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
244
- return dqkv, None, None, None, None, None, None, None, None, None, None, None
245
 
246
 
247
  class FlashAttnFunc(torch.autograd.Function):
@@ -263,6 +549,7 @@ class FlashAttnFunc(torch.autograd.Function):
263
  pack_gqa=None,
264
  deterministic=False,
265
  sm_margin=0,
 
266
  ):
267
  if softmax_scale is None:
268
  softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -282,7 +569,8 @@ class FlashAttnFunc(torch.autograd.Function):
282
  q_descale, k_descale, v_descale,
283
  softmax_scale,
284
  causal=causal,
285
- window_size=window_size,
 
286
  attention_chunk=attention_chunk,
287
  softcap=softcap,
288
  num_splits=num_splits,
@@ -298,7 +586,7 @@ class FlashAttnFunc(torch.autograd.Function):
298
  ctx.softcap = softcap
299
  ctx.deterministic = deterministic
300
  ctx.sm_margin = sm_margin
301
- return out, softmax_lse
302
 
303
  @staticmethod
304
  def backward(ctx, dout, *args):
@@ -320,7 +608,8 @@ class FlashAttnFunc(torch.autograd.Function):
320
  dv,
321
  ctx.softmax_scale,
322
  ctx.causal,
323
- ctx.window_size,
 
324
  ctx.softcap,
325
  ctx.deterministic,
326
  ctx.sm_margin,
@@ -356,6 +645,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
356
  pack_gqa=None,
357
  deterministic=False,
358
  sm_margin=0,
 
359
  ):
360
  if softmax_scale is None:
361
  softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -379,7 +669,8 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
379
  q_descale, k_descale, v_descale,
380
  softmax_scale,
381
  causal=causal,
382
- window_size=window_size,
 
383
  attention_chunk=attention_chunk,
384
  softcap=softcap,
385
  num_splits=num_splits,
@@ -397,7 +688,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
397
  ctx.softcap = softcap
398
  ctx.deterministic = deterministic
399
  ctx.sm_margin = sm_margin
400
- return out, softmax_lse
401
 
402
  @staticmethod
403
  def backward(ctx, dout, *args):
@@ -422,7 +713,8 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
422
  dv,
423
  ctx.softmax_scale,
424
  ctx.causal,
425
- ctx.window_size,
 
426
  ctx.softcap,
427
  ctx.deterministic,
428
  ctx.sm_margin,
@@ -444,6 +736,7 @@ def flash_attn_qkvpacked_func(
444
  deterministic=False,
445
  num_heads_q=None,
446
  sm_margin=0,
 
447
  ):
448
  """dropout_p should be set to 0.0 during evaluation
449
  If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -490,6 +783,7 @@ def flash_attn_qkvpacked_func(
490
  deterministic,
491
  num_heads_q,
492
  sm_margin,
 
493
  )
494
 
495
 
@@ -508,6 +802,7 @@ def flash_attn_func(
508
  pack_gqa=None,
509
  deterministic=False,
510
  sm_margin=0,
 
511
  ):
512
  """dropout_p should be set to 0.0 during evaluation
513
  Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -569,6 +864,7 @@ def flash_attn_func(
569
  pack_gqa,
570
  deterministic,
571
  sm_margin,
 
572
  )
573
 
574
 
@@ -593,6 +889,7 @@ def flash_attn_varlen_func(
593
  pack_gqa=None,
594
  deterministic=False,
595
  sm_margin=0,
 
596
  ):
597
  return FlashAttnVarlenFunc.apply(
598
  q,
@@ -615,6 +912,7 @@ def flash_attn_varlen_func(
615
  pack_gqa,
616
  deterministic,
617
  sm_margin,
 
618
  )
619
 
620
 
@@ -700,7 +998,7 @@ def flash_attn_with_kvcache(
700
  q: (batch_size, seqlen, nheads, headdim)
701
  k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
702
  or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
703
- page_block_size must be a multiple of 256.
704
  v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
705
  or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
706
  k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
@@ -745,7 +1043,7 @@ def flash_attn_with_kvcache(
745
  softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
746
  if cache_seqlens is not None and isinstance(cache_seqlens, int):
747
  cache_seqlens = torch.full(
748
- (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
749
  )
750
  cache_seqlens = maybe_contiguous(cache_seqlens)
751
  out, softmax_lse, *rest = _flash_attn_forward(
@@ -772,7 +1070,8 @@ def flash_attn_with_kvcache(
772
  q_descale, k_descale, v_descale,
773
  softmax_scale,
774
  causal=causal,
775
- window_size=window_size,
 
776
  attention_chunk=attention_chunk,
777
  softcap=softcap,
778
  rotary_interleaved=rotary_interleaved,
 
1
  # Copyright (c) 2023, Tri Dao.
2
 
3
+ from typing import Optional, Union, List, Tuple
4
 
5
  import torch
6
  import torch.nn as nn
7
 
8
  from ._ops import ops as flash_attn_3_cuda
9
+ from ._ops import add_op_namespace_prefix
10
 
11
  def maybe_contiguous(x):
12
  return x.contiguous() if x is not None and x.stride(-1) != 1 else x
13
 
14
 
15
+ def round_multiple(x, m):
16
+ return (x + m - 1) // m * m
17
+
18
+
19
+ def round_up_headdim(head_size: int) -> int:
20
+ from .flash_attn_config import CONFIG
21
+
22
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM64"]:
23
+ if head_size <= 64:
24
+ return 64
25
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM96"]:
26
+ if head_size <= 96:
27
+ return 96
28
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM128"]:
29
+ if head_size <= 128:
30
+ return 128
31
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM192"]:
32
+ if head_size <= 192:
33
+ return 192
34
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM256"]:
35
+ if head_size <= 256:
36
+ return 256
37
+ return 256
38
+
39
+
40
+ @torch.library.custom_op(add_op_namespace_prefix("_flash_attn_forward"), mutates_args=(), device_types="cuda")
41
  def _flash_attn_forward(
42
+ q: torch.Tensor,
43
+ k: torch.Tensor,
44
+ v: torch.Tensor,
45
+ k_new: Optional[torch.Tensor] = None,
46
+ v_new: Optional[torch.Tensor] = None,
47
+ qv: Optional[torch.Tensor] = None,
48
+ out_: Optional[torch.Tensor] = None,
49
+ cu_seqlens_q: Optional[torch.Tensor] = None,
50
+ cu_seqlens_k: Optional[torch.Tensor] = None,
51
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
52
+ seqused_q: Optional[torch.Tensor] = None,
53
+ seqused_k: Optional[torch.Tensor] = None,
54
+ max_seqlen_q: Optional[int] = None,
55
+ max_seqlen_k: Optional[int] = None,
56
+ page_table: Optional[torch.Tensor] = None,
57
+ kv_batch_idx: Optional[torch.Tensor] = None,
58
+ leftpad_k: Optional[torch.Tensor] = None,
59
+ rotary_cos: Optional[torch.Tensor] = None,
60
+ rotary_sin: Optional[torch.Tensor] = None,
61
+ seqlens_rotary: Optional[torch.Tensor] = None,
62
+ q_descale: Optional[torch.Tensor] = None,
63
+ k_descale: Optional[torch.Tensor] = None,
64
+ v_descale: Optional[torch.Tensor] = None,
65
+ softmax_scale: Optional[float] = None,
66
+ causal: bool = False,
67
+ window_size_left: int = -1,
68
+ window_size_right: int = -1,
69
+ attention_chunk: int = 0,
70
+ softcap: float = 0.0,
71
+ rotary_interleaved: bool = True,
72
+ scheduler_metadata: Optional[torch.Tensor] = None,
73
+ num_splits: int = 1,
74
+ pack_gqa: Optional[bool] = None,
75
+ sm_margin: int = 0,
76
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
77
  q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
78
  v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
79
  cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
 
85
  ]
86
  rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
87
  seqlens_rotary = maybe_contiguous(seqlens_rotary)
88
+ out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
89
  q,
90
  k,
91
  v,
92
  k_new,
93
  v_new,
94
  qv,
95
+ out_,
96
  cu_seqlens_q,
97
  cu_seqlens_k,
98
  cu_seqlens_k_new,
 
111
  v_descale,
112
  softmax_scale,
113
  causal,
114
+ window_size_left,
115
+ window_size_right,
116
  attention_chunk,
117
  softcap,
118
  rotary_interleaved,
 
121
  pack_gqa,
122
  sm_margin,
123
  )
 
124
 
125
+ if out_accum is None:
126
+ out_accum = torch.tensor([], device=out.device)
127
+
128
+ if softmax_lse_accum is None:
129
+ softmax_lse_accum = torch.tensor([], device=out.device)
130
+
131
+ return out, softmax_lse, out_accum, softmax_lse_accum
132
+
133
+
134
+ @torch.library.register_fake(add_op_namespace_prefix("_flash_attn_forward"))
135
+ def _flash_attn_forward_fake(
136
+ q: torch.Tensor,
137
+ k: torch.Tensor,
138
+ v: torch.Tensor,
139
+ k_new: Optional[torch.Tensor] = None,
140
+ v_new: Optional[torch.Tensor] = None,
141
+ qv: Optional[torch.Tensor] = None,
142
+ out_: Optional[torch.Tensor] = None,
143
+ cu_seqlens_q: Optional[torch.Tensor] = None,
144
+ cu_seqlens_k: Optional[torch.Tensor] = None,
145
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
146
+ seqused_q: Optional[torch.Tensor] = None,
147
+ seqused_k: Optional[torch.Tensor] = None,
148
+ max_seqlen_q: Optional[int] = None,
149
+ max_seqlen_k: Optional[int] = None,
150
+ page_table: Optional[torch.Tensor] = None,
151
+ kv_batch_idx: Optional[torch.Tensor] = None,
152
+ leftpad_k: Optional[torch.Tensor] = None,
153
+ rotary_cos: Optional[torch.Tensor] = None,
154
+ rotary_sin: Optional[torch.Tensor] = None,
155
+ seqlens_rotary: Optional[torch.Tensor] = None,
156
+ q_descale: Optional[torch.Tensor] = None,
157
+ k_descale: Optional[torch.Tensor] = None,
158
+ v_descale: Optional[torch.Tensor] = None,
159
+ softmax_scale: Optional[float] = None,
160
+ causal: bool = False,
161
+ window_size_left: int = -1,
162
+ window_size_right: int = -1,
163
+ attention_chunk: int = 0,
164
+ softcap: float = 0.0,
165
+ rotary_interleaved: bool = True,
166
+ scheduler_metadata: Optional[torch.Tensor] = None,
167
+ num_splits: int = 1,
168
+ pack_gqa: Optional[bool] = None,
169
+ sm_margin: int = 0,
170
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
171
+ """
172
+ Symbolic fake implementation of flash attention forward.
173
+ Returns tensors with the correct shapes and dtypes without actual computation.
174
+ """
175
+
176
+ # Determine if we're in varlen mode
177
+ is_varlen_q = cu_seqlens_q is not None
178
 
179
+ # Get dimensions from query tensor
180
+ if is_varlen_q:
181
+ # varlen mode: q is (total_q, num_heads, head_size)
182
+ total_q, num_heads, head_size = q.shape
183
+ batch_size = cu_seqlens_q.shape[0] - 1
184
+
185
+ if max_seqlen_q is None:
186
+ raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
187
+ seqlen_q = max_seqlen_q
188
+ else:
189
+ # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
190
+ batch_size, seqlen_q, num_heads, head_size = q.shape
191
+ total_q = batch_size * q.shape[1]
192
+ # Get value head dimension
193
+ head_size_v = v.shape[-1]
194
+
195
+ # Determine output dtype (FP8 inputs produce BF16 outputs)
196
+ q_type = q.dtype
197
+ if q_type == torch.float8_e4m3fn:
198
+ out_dtype = torch.bfloat16
199
+ else:
200
+ out_dtype = q_type
201
+
202
+ # Create output tensor
203
+ if out_ is not None:
204
+ # If out_ is provided, _flash_attn_forward becomes non-functional
205
+ raise TypeError("Tracing (torch.compile/torch.export) with pre-allocated output tensor is not supported.")
206
+
207
+ if is_varlen_q:
208
+ out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
209
+ else:
210
+ out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
211
+
212
+ # Create softmax_lse tensor
213
+ if is_varlen_q:
214
+ softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
215
+ else:
216
+ softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
217
+
218
+ # TODO(guilhermeleobas): Implement "get_num_splits"
219
+ # There's an heuristic to compute num_splits when "num_splits <= 0"
220
+ # assert that num_splits is > 0 for now
221
+ if num_splits <= 0:
222
+ raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")
223
+
224
+ if num_splits > 1:
225
+ if is_varlen_q:
226
+ out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
227
+ softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
228
+ else:
229
+ out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
230
+ softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
231
+ else:
232
+ # Tensors are not set when num_splits < 1
233
+ out_accum = torch.tensor([], device=out.device)
234
+ softmax_lse_accum = torch.tensor([], device=out.device)
235
+
236
+ return out, softmax_lse, out_accum, softmax_lse_accum
237
+
238
+
239
+ @torch.library.custom_op(add_op_namespace_prefix("_flash_attn_backward"), mutates_args=("dq", "dk", "dv"), device_types="cuda")
240
  def _flash_attn_backward(
241
+ dout: torch.Tensor,
242
+ q: torch.Tensor,
243
+ k: torch.Tensor,
244
+ v: torch.Tensor,
245
+ out: torch.Tensor,
246
+ softmax_lse: torch.Tensor,
247
+ cu_seqlens_q: Optional[torch.Tensor] = None,
248
+ cu_seqlens_k: Optional[torch.Tensor] = None,
249
+ sequed_q: Optional[torch.Tensor] = None,
250
+ sequed_k: Optional[torch.Tensor] = None,
251
+ max_seqlen_q: Optional[int] = None,
252
+ max_seqlen_k: Optional[int] = None,
253
+ dq: Optional[torch.Tensor] = None,
254
+ dk: Optional[torch.Tensor] = None,
255
+ dv: Optional[torch.Tensor] = None,
256
+ softmax_scale: Optional[float] = None,
257
+ is_causal: bool = False,
258
+ window_size_left: int = -1,
259
+ window_size_right: int = -1,
260
+ softcap: float = 0.0,
261
+ deterministic: bool = False,
262
+ sm_margin: int = 0,
263
+ ) -> torch.Tensor:
264
+ # dq, dk, dv are allocated by us so they should already be contiguous
265
+ dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
266
+ softmax_d, *rest = flash_attn_3_cuda.bwd(
267
  dout,
268
  q,
269
  k,
270
  v,
271
  out,
272
  softmax_lse,
273
+ dq,
274
+ dk,
275
+ dv,
276
  cu_seqlens_q,
277
  cu_seqlens_k,
278
  sequed_q,
279
  sequed_k,
280
  max_seqlen_q,
281
  max_seqlen_k,
282
  softmax_scale,
283
+ is_causal,
284
+ window_size_left,
285
+ window_size_right,
286
+ softcap,
287
+ deterministic,
288
+ sm_margin,
289
+ )
290
+ return softmax_d
291
+
292
+
293
+ @torch.library.register_fake(add_op_namespace_prefix("_flash_attn_backward"))
294
+ def _flash_attn_backward_fake(
295
+ dout: torch.Tensor,
296
+ q: torch.Tensor,
297
+ k: torch.Tensor,
298
+ v: torch.Tensor,
299
+ out: torch.Tensor,
300
+ softmax_lse: torch.Tensor,
301
+ cu_seqlens_q: Optional[torch.Tensor] = None,
302
+ cu_seqlens_k: Optional[torch.Tensor] = None,
303
+ sequed_q: Optional[torch.Tensor] = None,
304
+ sequed_k: Optional[torch.Tensor] = None,
305
+ max_seqlen_q: Optional[int] = None,
306
+ max_seqlen_k: Optional[int] = None,
307
+ dq: Optional[torch.Tensor] = None,
308
+ dk: Optional[torch.Tensor] = None,
309
+ dv: Optional[torch.Tensor] = None,
310
+ softmax_scale: Optional[float] = None,
311
+ is_causal: bool = False,
312
+ window_size_left: int = -1,
313
+ window_size_right: int = -1,
314
+ softcap: float = 0.0,
315
+ deterministic: bool = False,
316
+ sm_margin: int = 0,
317
+ ) -> torch.Tensor:
318
+
319
+ is_varlen_q = cu_seqlens_q is not None
320
+ is_varlen_k = cu_seqlens_q is not None
321
+ is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None
322
+
323
+ if not is_varlen_q:
324
+ batch_size = q.size(0)
325
+ seqlen_q = q.size(1)
326
+ seqlen_k = k.size(1)
327
+ total_q = batch_size * q.size(1)
328
+ else:
329
+ batch_size = cu_seqlens_q.size(0) - 1
330
+ total_q = q.size(0)
331
+ seqlen_q = max_seqlen_q
332
+ seqlen_k = max_seqlen_k
333
+
334
+ if window_size_left >= seqlen_k - 1:
335
+ window_size_left = -1
336
+
337
+ if window_size_right >= seqlen_q - 1:
338
+ window_size_right = -1
339
+
340
+ if is_causal:
341
+ window_size_right = 0
342
+
343
+ is_causal = window_size_left < 0 and window_size_right == 0
344
+
345
+ head_size = q.size(-1)
346
+ head_size_v = v.size(-1)
347
+ head_size_rounded = round_up_headdim(max(head_size, head_size_v))
348
+
349
+ # Hopper gpus uses cuda compute capabilities 9.0
350
+ cap = torch.cuda.get_device_capability(q.device)
351
+ arch = cap[0] * 10 + cap[1]
352
+
353
+ is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal
354
+
355
+ if head_size_rounded <= 64:
356
+ kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
357
+ elif head_size_rounded <= 96:
358
+ kBlockM_sm90 = 64
359
+ elif head_size_rounded <= 128:
360
+ kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
361
+ else:
362
+ kBlockM_sm90 = 64
363
+
364
+ kBlockM_sm80 = 128 if head_size_rounded <= 64 else 64
365
+ kBlockM_sm86 = 64 if head_size_rounded <= 192 else 32
366
+
367
+ if arch >= 90:
368
+ kBlockM = kBlockM_sm90
369
+ elif arch == 86 or arch == 89:
370
+ kBlockM = kBlockM_sm86
371
+ else:
372
+ kBlockM = kBlockM_sm80
373
+
374
+ num_heads = q.shape[-2]
375
+ seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)
376
+
377
+ total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)
378
+
379
+ dq = torch.empty_like(q) if dq is None else dq
380
+ dk = torch.empty_like(k) if dk is None else dk
381
+ dv = torch.empty_like(v) if dv is None else dv
382
+
383
+ if not is_varlen:
384
+ softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
385
+ else:
386
+ softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)
387
+
388
+ return softmax_d
389
+
390
+
391
+ def setup_context(ctx, inputs, output):
392
+ q, k, v = inputs[:3]
393
+ out, softmax_lse, _, _ = output
394
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
395
+ ctx.softmax_scale = inputs[-11]
396
+ ctx.causal = inputs[-10]
397
+ ctx.window_size = [inputs[-9], inputs[-8]]
398
+ ctx.attention_chunk = inputs[-7]
399
+ ctx.softcap = inputs[-6]
400
+ ctx.sm_margin = inputs[-1]
401
+
402
+
403
+ def _backward(ctx, dout, *grads):
404
+ q, k, v, out, softmax_lse = ctx.saved_tensors
405
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
406
+ _flash_attn_backward(
407
  dout,
408
  q,
409
  k,
410
  v,
411
  out,
412
  softmax_lse,
413
+ None, None, # cu_seqlens_q, cu_seqlens_k,
414
+ None, None, # sequed_q, sequed_k,
415
+ None, None, # max_seqlen_q, max_seqlen_k,
416
  dq,
417
  dk,
418
  dv,
419
+ ctx.softmax_scale,
420
+ ctx.causal,
421
+ ctx.window_size[0],
422
+ ctx.window_size[1],
423
+ ctx.softcap,
424
+ False, # deterministic
425
+ ctx.sm_margin,
426
  )
427
+ return dq, dk, dv, *((None,) * 21)
428
+
429
+
430
+ _flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
431
+
432
 
433
 
434
  class FlashAttnQKVPackedFunc(torch.autograd.Function):
 
445
  deterministic=False,
446
  num_heads_q=None,
447
  sm_margin=0,
448
+ return_softmax=False,
449
  ):
450
  if softmax_scale is None:
451
  softmax_scale = qkv.shape[-1] ** (-0.5)
 
473
  q_descale, k_descale, v_descale,
474
  softmax_scale,
475
  causal=causal,
476
+ window_size_left=window_size[0],
477
+ window_size_right=window_size[1],
478
  attention_chunk=attention_chunk,
479
  softcap=softcap,
480
  sm_margin=sm_margin,
 
489
  ctx.deterministic = deterministic
490
  ctx.ndim = qkv.dim()
491
  ctx.sm_margin = sm_margin
492
+ return (out, softmax_lse) if return_softmax else out
 
493
 
494
  @staticmethod
495
  def backward(ctx, dout, *args):
 
520
  dv,
521
  ctx.softmax_scale,
522
  ctx.causal,
523
+ ctx.window_size[0],
524
+ ctx.window_size[1],
525
  ctx.softcap,
526
  ctx.deterministic,
527
  ctx.sm_margin,
528
  )
529
  dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
530
+ return dqkv, None, None, None, None, None, None, None, None, None, None, None, None
531
 
532
 
533
  class FlashAttnFunc(torch.autograd.Function):
 
549
  pack_gqa=None,
550
  deterministic=False,
551
  sm_margin=0,
552
+ return_softmax=False,
553
  ):
554
  if softmax_scale is None:
555
  softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
 
569
  q_descale, k_descale, v_descale,
570
  softmax_scale,
571
  causal=causal,
572
+ window_size_left=window_size[0],
573
+ window_size_right=window_size[1],
574
  attention_chunk=attention_chunk,
575
  softcap=softcap,
576
  num_splits=num_splits,
 
586
  ctx.softcap = softcap
587
  ctx.deterministic = deterministic
588
  ctx.sm_margin = sm_margin
589
+ return (out, softmax_lse) if return_softmax else out
590
 
591
  @staticmethod
592
  def backward(ctx, dout, *args):
 
608
  dv,
609
  ctx.softmax_scale,
610
  ctx.causal,
611
+ ctx.window_size[0],
612
+ ctx.window_size[1],
613
  ctx.softcap,
614
  ctx.deterministic,
615
  ctx.sm_margin,
 
645
  pack_gqa=None,
646
  deterministic=False,
647
  sm_margin=0,
648
+ return_softmax=False,
649
  ):
650
  if softmax_scale is None:
651
  softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
 
669
  q_descale, k_descale, v_descale,
670
  softmax_scale,
671
  causal=causal,
672
+ window_size_left=window_size[0],
673
+ window_size_right=window_size[1],
674
  attention_chunk=attention_chunk,
675
  softcap=softcap,
676
  num_splits=num_splits,
 
688
  ctx.softcap = softcap
689
  ctx.deterministic = deterministic
690
  ctx.sm_margin = sm_margin
691
+ return (out, softmax_lse) if return_softmax else out
692
 
693
  @staticmethod
694
  def backward(ctx, dout, *args):
 
713
  dv,
714
  ctx.softmax_scale,
715
  ctx.causal,
716
+ ctx.window_size[0],
717
+ ctx.window_size[1],
718
  ctx.softcap,
719
  ctx.deterministic,
720
  ctx.sm_margin,
 
736
  deterministic=False,
737
  num_heads_q=None,
738
  sm_margin=0,
739
+ return_attn_probs=False,
740
  ):
741
  """dropout_p should be set to 0.0 during evaluation
742
  If Q, K, V are already stacked into 1 tensor, this function will be faster than
 
783
  deterministic,
784
  num_heads_q,
785
  sm_margin,
786
+ return_attn_probs,
787
  )
788
 
789
 
 
802
  pack_gqa=None,
803
  deterministic=False,
804
  sm_margin=0,
805
+ return_attn_probs=False,
806
  ):
807
  """dropout_p should be set to 0.0 during evaluation
808
  Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
 
864
  pack_gqa,
865
  deterministic,
866
  sm_margin,
867
+ return_attn_probs,
868
  )
869
 
870
 
 
889
  pack_gqa=None,
890
  deterministic=False,
891
  sm_margin=0,
892
+ return_attn_probs=False,
893
  ):
894
  return FlashAttnVarlenFunc.apply(
895
  q,
 
912
  pack_gqa,
913
  deterministic,
914
  sm_margin,
915
+ return_attn_probs,
916
  )
917
 
918
 
 
998
  q: (batch_size, seqlen, nheads, headdim)
999
  k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
1000
  or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
1001
+ page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.).
1002
  v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
1003
  or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
1004
  k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
 
1043
  softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
1044
  if cache_seqlens is not None and isinstance(cache_seqlens, int):
1045
  cache_seqlens = torch.full(
1046
+ (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
1047
  )
1048
  cache_seqlens = maybe_contiguous(cache_seqlens)
1049
  out, softmax_lse, *rest = _flash_attn_forward(
 
1070
  q_descale, k_descale, v_descale,
1071
  softmax_scale,
1072
  causal=causal,
1073
+ window_size_left=window_size[0],
1074
+ window_size_right=window_size[1],
1075
  attention_chunk=attention_chunk,
1076
  softcap=softcap,
1077
  rotary_interleaved=rotary_interleaved,
build/torch29-cxx11-cu128-x86_64-linux/metadata.json ADDED
@@ -0,0 +1 @@
+ {"python-depends":[]}
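The interface change above also alters what the public entry points return: `flash_attn_func` and `flash_attn_varlen_func` now return only the attention output by default, handing back the softmax log-sum-exp alongside it only when `return_attn_probs=True` is passed (the QKV-packed variant gains the same flag). A short usage sketch of the new behaviour (shapes are illustrative, and the import path assumes the module is loaded as `flash_attn_interface`):

```python
import torch
from flash_attn_interface import flash_attn_func  # import path is an assumption

q = torch.randn(2, 512, 16, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(2, 512, 16, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2, 512, 16, 128, dtype=torch.bfloat16, device="cuda")

out = flash_attn_func(q, k, v, causal=True)             # new default: output only
out, lse = flash_attn_func(q, k, v, causal=True,
                           return_attn_probs=True)      # also return the softmax LSE
```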
build/torch29-cxx11-cu130-x86_64-linux/{_flash_attn3_1d39a44.abi3.so → _flash_attn3_7e34105_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:be71ff231157cea9508af9835b6ed814f5010ed159b073c74ef660ad5d3885f3
- size 795878440
+ oid sha256:8b71000e0cda6d5dd2f967eb990bdc6d41db679a868bb9f2f96fdfc6b0caa461
+ size 823719112
build/torch29-cxx11-cu130-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _flash_attn3_1d39a44
- ops = torch.ops._flash_attn3_1d39a44
+ from . import _flash_attn3_7e34105_dirty
+ ops = torch.ops._flash_attn3_7e34105_dirty

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_flash_attn3_1d39a44::{op_name}"
+     return f"_flash_attn3_7e34105_dirty::{op_name}"
build/torch29-cxx11-cu130-x86_64-linux/flash_attn_interface.py CHANGED
@@ -1,50 +1,79 @@
1
  # Copyright (c) 2023, Tri Dao.
2
 
3
- from typing import Optional, Union
4
 
5
  import torch
6
  import torch.nn as nn
7
 
8
  from ._ops import ops as flash_attn_3_cuda
 
9
 
10
  def maybe_contiguous(x):
11
  return x.contiguous() if x is not None and x.stride(-1) != 1 else x
12
 
13
14
  def _flash_attn_forward(
15
- q,
16
- k,
17
- v,
18
- k_new,
19
- v_new,
20
- qv,
21
- out,
22
- cu_seqlens_q,
23
- cu_seqlens_k,
24
- cu_seqlens_k_new,
25
- seqused_q,
26
- seqused_k,
27
- max_seqlen_q,
28
- max_seqlen_k,
29
- page_table,
30
- kv_batch_idx,
31
- leftpad_k,
32
- rotary_cos,
33
- rotary_sin,
34
- seqlens_rotary,
35
- q_descale,
36
- k_descale,
37
- v_descale,
38
- softmax_scale,
39
- causal,
40
- window_size=(-1, -1),
41
- attention_chunk=0,
42
- softcap=0.0,
43
- rotary_interleaved=True,
44
- scheduler_metadata=None,
45
- num_splits=1,
46
- pack_gqa=None,
47
- sm_margin=0):
 
 
48
  q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
49
  v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
50
  cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
@@ -56,14 +85,14 @@ def _flash_attn_forward(
56
  ]
57
  rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
58
  seqlens_rotary = maybe_contiguous(seqlens_rotary)
59
- out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
60
  q,
61
  k,
62
  v,
63
  k_new,
64
  v_new,
65
  qv,
66
- out,
67
  cu_seqlens_q,
68
  cu_seqlens_k,
69
  cu_seqlens_k_new,
@@ -82,8 +111,8 @@ def _flash_attn_forward(
82
  v_descale,
83
  softmax_scale,
84
  causal,
85
- window_size[0],
86
- window_size[1],
87
  attention_chunk,
88
  softcap,
89
  rotary_interleaved,
@@ -92,59 +121,314 @@ def _flash_attn_forward(
92
  pack_gqa,
93
  sm_margin,
94
  )
95
- return out, softmax_lse, *rest
96
 
97
 
98
  def _flash_attn_backward(
99
  dout,
100
  q,
101
  k,
102
  v,
103
  out,
104
  softmax_lse,
105
  cu_seqlens_q,
106
  cu_seqlens_k,
107
  sequed_q,
108
  sequed_k,
109
  max_seqlen_q,
110
  max_seqlen_k,
111
- dq,
112
- dk,
113
- dv,
114
  softmax_scale,
115
- causal,
116
- window_size=(-1, -1),
117
- softcap=0.0,
118
- deterministic=False,
119
- sm_margin=0,
120
- ):
121
- # dq, dk, dv are allocated by us so they should already be contiguous
122
- dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
123
- dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
124
  dout,
125
  q,
126
  k,
127
  v,
128
  out,
129
  softmax_lse,
130
  dq,
131
  dk,
132
  dv,
133
- cu_seqlens_q,
134
- cu_seqlens_k,
135
- sequed_q,
136
- sequed_k,
137
- max_seqlen_q,
138
- max_seqlen_k,
139
- softmax_scale,
140
- causal,
141
- window_size[0],
142
- window_size[1],
143
- softcap,
144
- deterministic,
145
- sm_margin,
146
  )
147
- return dq, dk, dv, softmax_d
148
 
149
 
150
  class FlashAttnQKVPackedFunc(torch.autograd.Function):
@@ -161,6 +445,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
161
  deterministic=False,
162
  num_heads_q=None,
163
  sm_margin=0,
 
164
  ):
165
  if softmax_scale is None:
166
  softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -188,7 +473,8 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
188
  q_descale, k_descale, v_descale,
189
  softmax_scale,
190
  causal=causal,
191
- window_size=window_size,
 
192
  attention_chunk=attention_chunk,
193
  softcap=softcap,
194
  sm_margin=sm_margin,
@@ -203,8 +489,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
203
  ctx.deterministic = deterministic
204
  ctx.ndim = qkv.dim()
205
  ctx.sm_margin = sm_margin
206
- # return out, softmax_lse
207
- return out
208
 
209
  @staticmethod
210
  def backward(ctx, dout, *args):
@@ -235,13 +520,14 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
235
  dv,
236
  ctx.softmax_scale,
237
  ctx.causal,
238
- ctx.window_size,
 
239
  ctx.softcap,
240
  ctx.deterministic,
241
  ctx.sm_margin,
242
  )
243
  dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
244
- return dqkv, None, None, None, None, None, None, None, None, None, None, None
245
 
246
 
247
  class FlashAttnFunc(torch.autograd.Function):
@@ -263,6 +549,7 @@ class FlashAttnFunc(torch.autograd.Function):
263
  pack_gqa=None,
264
  deterministic=False,
265
  sm_margin=0,
 
266
  ):
267
  if softmax_scale is None:
268
  softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -282,7 +569,8 @@ class FlashAttnFunc(torch.autograd.Function):
282
  q_descale, k_descale, v_descale,
283
  softmax_scale,
284
  causal=causal,
285
- window_size=window_size,
 
286
  attention_chunk=attention_chunk,
287
  softcap=softcap,
288
  num_splits=num_splits,
@@ -298,7 +586,7 @@ class FlashAttnFunc(torch.autograd.Function):
298
  ctx.softcap = softcap
299
  ctx.deterministic = deterministic
300
  ctx.sm_margin = sm_margin
301
- return out, softmax_lse
302
 
303
  @staticmethod
304
  def backward(ctx, dout, *args):
@@ -320,7 +608,8 @@ class FlashAttnFunc(torch.autograd.Function):
320
  dv,
321
  ctx.softmax_scale,
322
  ctx.causal,
323
- ctx.window_size,
 
324
  ctx.softcap,
325
  ctx.deterministic,
326
  ctx.sm_margin,
@@ -356,6 +645,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
356
  pack_gqa=None,
357
  deterministic=False,
358
  sm_margin=0,
 
359
  ):
360
  if softmax_scale is None:
361
  softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -379,7 +669,8 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
379
  q_descale, k_descale, v_descale,
380
  softmax_scale,
381
  causal=causal,
382
- window_size=window_size,
 
383
  attention_chunk=attention_chunk,
384
  softcap=softcap,
385
  num_splits=num_splits,
@@ -397,7 +688,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
397
  ctx.softcap = softcap
398
  ctx.deterministic = deterministic
399
  ctx.sm_margin = sm_margin
400
- return out, softmax_lse
401
 
402
  @staticmethod
403
  def backward(ctx, dout, *args):
@@ -422,7 +713,8 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
422
  dv,
423
  ctx.softmax_scale,
424
  ctx.causal,
425
- ctx.window_size,
 
426
  ctx.softcap,
427
  ctx.deterministic,
428
  ctx.sm_margin,
@@ -444,6 +736,7 @@ def flash_attn_qkvpacked_func(
444
  deterministic=False,
445
  num_heads_q=None,
446
  sm_margin=0,
 
447
  ):
448
  """dropout_p should be set to 0.0 during evaluation
449
  If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -490,6 +783,7 @@ def flash_attn_qkvpacked_func(
490
  deterministic,
491
  num_heads_q,
492
  sm_margin,
 
493
  )
494
 
495
 
@@ -508,6 +802,7 @@ def flash_attn_func(
508
  pack_gqa=None,
509
  deterministic=False,
510
  sm_margin=0,
 
511
  ):
512
  """dropout_p should be set to 0.0 during evaluation
513
  Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -569,6 +864,7 @@ def flash_attn_func(
569
  pack_gqa,
570
  deterministic,
571
  sm_margin,
 
572
  )
573
 
574
 
@@ -593,6 +889,7 @@ def flash_attn_varlen_func(
593
  pack_gqa=None,
594
  deterministic=False,
595
  sm_margin=0,
 
596
  ):
597
  return FlashAttnVarlenFunc.apply(
598
  q,
@@ -615,6 +912,7 @@ def flash_attn_varlen_func(
615
  pack_gqa,
616
  deterministic,
617
  sm_margin,
 
618
  )
619
 
620
 
@@ -700,7 +998,7 @@ def flash_attn_with_kvcache(
700
  q: (batch_size, seqlen, nheads, headdim)
701
  k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
702
  or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
703
- page_block_size must be a multiple of 256.
704
  v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
705
  or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
706
  k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
@@ -745,7 +1043,7 @@ def flash_attn_with_kvcache(
745
  softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
746
  if cache_seqlens is not None and isinstance(cache_seqlens, int):
747
  cache_seqlens = torch.full(
748
- (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
749
  )
750
  cache_seqlens = maybe_contiguous(cache_seqlens)
751
  out, softmax_lse, *rest = _flash_attn_forward(
@@ -772,7 +1070,8 @@ def flash_attn_with_kvcache(
772
  q_descale, k_descale, v_descale,
773
  softmax_scale,
774
  causal=causal,
775
- window_size=window_size,
 
776
  attention_chunk=attention_chunk,
777
  softcap=softcap,
778
  rotary_interleaved=rotary_interleaved,
 
1
  # Copyright (c) 2023, Tri Dao.
2
 
3
+ from typing import Optional, Union, List, Tuple
4
 
5
  import torch
6
  import torch.nn as nn
7
 
8
  from ._ops import ops as flash_attn_3_cuda
9
+ from ._ops import add_op_namespace_prefix
10
 
11
  def maybe_contiguous(x):
12
  return x.contiguous() if x is not None and x.stride(-1) != 1 else x
13
 
14
 
15
+ def round_multiple(x, m):
16
+ return (x + m - 1) // m * m
17
+
18
+
19
+ def round_up_headdim(head_size: int) -> int:
20
+ from .flash_attn_config import CONFIG
21
+
22
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM64"]:
23
+ if head_size <= 64:
24
+ return 64
25
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM96"]:
26
+ if head_size <= 96:
27
+ return 96
28
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM128"]:
29
+ if head_size <= 128:
30
+ return 128
31
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM192"]:
32
+ if head_size <= 192:
33
+ return 192
34
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM256"]:
35
+ if head_size <= 256:
36
+ return 256
37
+ return 256
38
+
39
+
40
+ @torch.library.custom_op(add_op_namespace_prefix("_flash_attn_forward"), mutates_args=(), device_types="cuda")
41
  def _flash_attn_forward(
42
+ q: torch.Tensor,
43
+ k: torch.Tensor,
44
+ v: torch.Tensor,
45
+ k_new: Optional[torch.Tensor] = None,
46
+ v_new: Optional[torch.Tensor] = None,
47
+ qv: Optional[torch.Tensor] = None,
48
+ out_: Optional[torch.Tensor] = None,
49
+ cu_seqlens_q: Optional[torch.Tensor] = None,
50
+ cu_seqlens_k: Optional[torch.Tensor] = None,
51
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
52
+ seqused_q: Optional[torch.Tensor] = None,
53
+ seqused_k: Optional[torch.Tensor] = None,
54
+ max_seqlen_q: Optional[int] = None,
55
+ max_seqlen_k: Optional[int] = None,
56
+ page_table: Optional[torch.Tensor] = None,
57
+ kv_batch_idx: Optional[torch.Tensor] = None,
58
+ leftpad_k: Optional[torch.Tensor] = None,
59
+ rotary_cos: Optional[torch.Tensor] = None,
60
+ rotary_sin: Optional[torch.Tensor] = None,
61
+ seqlens_rotary: Optional[torch.Tensor] = None,
62
+ q_descale: Optional[torch.Tensor] = None,
63
+ k_descale: Optional[torch.Tensor] = None,
64
+ v_descale: Optional[torch.Tensor] = None,
65
+ softmax_scale: Optional[float] = None,
66
+ causal: bool = False,
67
+ window_size_left: int = -1,
68
+ window_size_right: int = -1,
69
+ attention_chunk: int = 0,
70
+ softcap: float = 0.0,
71
+ rotary_interleaved: bool = True,
72
+ scheduler_metadata: Optional[torch.Tensor] = None,
73
+ num_splits: int = 1,
74
+ pack_gqa: Optional[bool] = None,
75
+ sm_margin: int = 0,
76
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
77
  q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
78
  v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
79
  cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
 
85
  ]
86
  rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
87
  seqlens_rotary = maybe_contiguous(seqlens_rotary)
88
+ out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
89
  q,
90
  k,
91
  v,
92
  k_new,
93
  v_new,
94
  qv,
95
+ out_,
96
  cu_seqlens_q,
97
  cu_seqlens_k,
98
  cu_seqlens_k_new,
 
111
  v_descale,
112
  softmax_scale,
113
  causal,
114
+ window_size_left,
115
+ window_size_right,
116
  attention_chunk,
117
  softcap,
118
  rotary_interleaved,
 
121
  pack_gqa,
122
  sm_margin,
123
  )
 
124
 
125
+ if out_accum is None:
126
+ out_accum = torch.tensor([], device=out.device)
127
+
128
+ if softmax_lse_accum is None:
129
+ softmax_lse_accum = torch.tensor([], device=out.device)
130
+
131
+ return out, softmax_lse, out_accum, softmax_lse_accum
132
+
133
+
134
+ @torch.library.register_fake(add_op_namespace_prefix("_flash_attn_forward"))
135
+ def _flash_attn_forward_fake(
136
+ q: torch.Tensor,
137
+ k: torch.Tensor,
138
+ v: torch.Tensor,
139
+ k_new: Optional[torch.Tensor] = None,
140
+ v_new: Optional[torch.Tensor] = None,
141
+ qv: Optional[torch.Tensor] = None,
142
+ out_: Optional[torch.Tensor] = None,
143
+ cu_seqlens_q: Optional[torch.Tensor] = None,
144
+ cu_seqlens_k: Optional[torch.Tensor] = None,
145
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
146
+ seqused_q: Optional[torch.Tensor] = None,
147
+ seqused_k: Optional[torch.Tensor] = None,
148
+ max_seqlen_q: Optional[int] = None,
149
+ max_seqlen_k: Optional[int] = None,
150
+ page_table: Optional[torch.Tensor] = None,
151
+ kv_batch_idx: Optional[torch.Tensor] = None,
152
+ leftpad_k: Optional[torch.Tensor] = None,
153
+ rotary_cos: Optional[torch.Tensor] = None,
154
+ rotary_sin: Optional[torch.Tensor] = None,
155
+ seqlens_rotary: Optional[torch.Tensor] = None,
156
+ q_descale: Optional[torch.Tensor] = None,
157
+ k_descale: Optional[torch.Tensor] = None,
158
+ v_descale: Optional[torch.Tensor] = None,
159
+ softmax_scale: Optional[float] = None,
160
+ causal: bool = False,
161
+ window_size_left: int = -1,
162
+ window_size_right: int = -1,
163
+ attention_chunk: int = 0,
164
+ softcap: float = 0.0,
165
+ rotary_interleaved: bool = True,
166
+ scheduler_metadata: Optional[torch.Tensor] = None,
167
+ num_splits: int = 1,
168
+ pack_gqa: Optional[bool] = None,
169
+ sm_margin: int = 0,
170
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
171
+ """
172
+ Symbolic fake implementation of flash attention forward.
173
+ Returns tensors with the correct shapes and dtypes without actual computation.
174
+ """
175
+
176
+ # Determine if we're in varlen mode
177
+ is_varlen_q = cu_seqlens_q is not None
178
 
179
+ # Get dimensions from query tensor
180
+ if is_varlen_q:
181
+ # varlen mode: q is (total_q, num_heads, head_size)
182
+ total_q, num_heads, head_size = q.shape
183
+ batch_size = cu_seqlens_q.shape[0] - 1
184
+
185
+ if max_seqlen_q is None:
186
+ raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
187
+ seqlen_q = max_seqlen_q
188
+ else:
189
+ # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
190
+ batch_size, seqlen_q, num_heads, head_size = q.shape
191
+ total_q = batch_size * q.shape[1]
192
+ # Get value head dimension
193
+ head_size_v = v.shape[-1]
194
+
195
+ # Determine output dtype (FP8 inputs produce BF16 outputs)
196
+ q_type = q.dtype
197
+ if q_type == torch.float8_e4m3fn:
198
+ out_dtype = torch.bfloat16
199
+ else:
200
+ out_dtype = q_type
201
+
202
+ # Create output tensor
203
+ if out_ is not None:
204
+ # If out_ is provided, _flash_attn_forward becomes non-functional
205
+ raise TypeError("Tracing (torch.compile/torch.export) with pre-allocated output tensor is not supported.")
206
+
207
+ if is_varlen_q:
208
+ out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
209
+ else:
210
+ out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
211
+
212
+ # Create softmax_lse tensor
213
+ if is_varlen_q:
214
+ softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
215
+ else:
216
+ softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
217
+
218
+ # TODO(guilhermeleobas): Implement "get_num_splits"
219
+ # There's an heuristic to compute num_splits when "num_splits <= 0"
220
+ # assert that num_splits is > 0 for now
221
+ if num_splits <= 0:
222
+ raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")
223
+
224
+ if num_splits > 1:
225
+ if is_varlen_q:
226
+ out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
227
+ softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
228
+ else:
229
+ out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
230
+ softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
231
+ else:
232
+ # Tensors are not set when num_splits < 1
233
+ out_accum = torch.tensor([], device=out.device)
234
+ softmax_lse_accum = torch.tensor([], device=out.device)
235
+
236
+ return out, softmax_lse, out_accum, softmax_lse_accum
237
+
238
+
239
+ @torch.library.custom_op(add_op_namespace_prefix("_flash_attn_backward"), mutates_args=("dq", "dk", "dv"), device_types="cuda")
240
  def _flash_attn_backward(
241
+ dout: torch.Tensor,
242
+ q: torch.Tensor,
243
+ k: torch.Tensor,
244
+ v: torch.Tensor,
245
+ out: torch.Tensor,
246
+ softmax_lse: torch.Tensor,
247
+ cu_seqlens_q: Optional[torch.Tensor] = None,
248
+ cu_seqlens_k: Optional[torch.Tensor] = None,
249
+ sequed_q: Optional[torch.Tensor] = None,
250
+ sequed_k: Optional[torch.Tensor] = None,
251
+ max_seqlen_q: Optional[int] = None,
252
+ max_seqlen_k: Optional[int] = None,
253
+ dq: Optional[torch.Tensor] = None,
254
+ dk: Optional[torch.Tensor] = None,
255
+ dv: Optional[torch.Tensor] = None,
256
+ softmax_scale: Optional[float] = None,
257
+ is_causal: bool = False,
258
+ window_size_left: int = -1,
259
+ window_size_right: int = -1,
260
+ softcap: float = 0.0,
261
+ deterministic: bool = False,
262
+ sm_margin: int = 0,
263
+ ) -> torch.Tensor:
264
+ # dq, dk, dv are allocated by us so they should already be contiguous
265
+ dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
266
+ softmax_d, *rest = flash_attn_3_cuda.bwd(
267
  dout,
268
  q,
269
  k,
270
  v,
271
  out,
272
  softmax_lse,
273
+ dq,
274
+ dk,
275
+ dv,
276
  cu_seqlens_q,
277
  cu_seqlens_k,
278
  sequed_q,
279
  sequed_k,
280
  max_seqlen_q,
281
  max_seqlen_k,
282
  softmax_scale,
283
+ is_causal,
284
+ window_size_left,
285
+ window_size_right,
286
+ softcap,
287
+ deterministic,
288
+ sm_margin,
289
+ )
290
+ return softmax_d
291
+
292
+
293
+ @torch.library.register_fake(add_op_namespace_prefix("_flash_attn_backward"))
294
+ def _flash_attn_backward_fake(
295
+ dout: torch.Tensor,
296
+ q: torch.Tensor,
297
+ k: torch.Tensor,
298
+ v: torch.Tensor,
299
+ out: torch.Tensor,
300
+ softmax_lse: torch.Tensor,
301
+ cu_seqlens_q: Optional[torch.Tensor] = None,
302
+ cu_seqlens_k: Optional[torch.Tensor] = None,
303
+ sequed_q: Optional[torch.Tensor] = None,
304
+ sequed_k: Optional[torch.Tensor] = None,
305
+ max_seqlen_q: Optional[int] = None,
306
+ max_seqlen_k: Optional[int] = None,
307
+ dq: Optional[torch.Tensor] = None,
308
+ dk: Optional[torch.Tensor] = None,
309
+ dv: Optional[torch.Tensor] = None,
310
+ softmax_scale: Optional[float] = None,
311
+ is_causal: bool = False,
312
+ window_size_left: int = -1,
313
+ window_size_right: int = -1,
314
+ softcap: float = 0.0,
315
+ deterministic: bool = False,
316
+ sm_margin: int = 0,
317
+ ) -> torch.Tensor:
318
+
319
+ is_varlen_q = cu_seqlens_q is not None
320
+ is_varlen_k = cu_seqlens_q is not None
321
+ is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None
322
+
323
+ if not is_varlen_q:
324
+ batch_size = q.size(0)
325
+ seqlen_q = q.size(1)
326
+ seqlen_k = k.size(1)
327
+ total_q = batch_size * q.size(1)
328
+ else:
329
+ batch_size = cu_seqlens_q.size(0) - 1
330
+ total_q = q.size(0)
331
+ seqlen_q = max_seqlen_q
332
+ seqlen_k = max_seqlen_k
333
+
334
+ if window_size_left >= seqlen_k - 1:
335
+ window_size_left = -1
336
+
337
+ if window_size_right >= seqlen_q - 1:
338
+ window_size_right = -1
339
+
340
+ if is_causal:
341
+ window_size_right = 0
342
+
343
+ is_causal = window_size_left < 0 and window_size_right == 0
344
+
345
+ head_size = q.size(-1)
346
+ head_size_v = v.size(-1)
347
+ head_size_rounded = round_up_headdim(max(head_size, head_size_v))
348
+
349
+ # Hopper gpus uses cuda compute capabilities 9.0
350
+ cap = torch.cuda.get_device_capability(q.device)
351
+ arch = cap[0] * 10 + cap[1]
352
+
353
+ is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal
354
+
355
+ if head_size_rounded <= 64:
356
+ kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
357
+ elif head_size_rounded <= 96:
358
+ kBlockM_sm90 = 64
359
+ elif head_size_rounded <= 128:
360
+ kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
361
+ else:
362
+ kBlockM_sm90 = 64
363
+
364
+ kBlockM_sm80 = 128 if head_size_rounded <= 64 else 64
365
+ kBlockM_sm86 = 64 if head_size_rounded <= 192 else 32
366
+
367
+ if arch >= 90:
368
+ kBlockM = kBlockM_sm90
369
+ elif arch == 86 or arch == 89:
370
+ kBlockM = kBlockM_sm86
371
+ else:
372
+ kBlockM = kBlockM_sm80
373
+
374
+ num_heads = q.shape[-2]
375
+ seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)
376
+
377
+ total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)
378
+
379
+ dq = torch.empty_like(q) if dq is None else dq
380
+ dk = torch.empty_like(k) if dk is None else dk
381
+ dv = torch.empty_like(v) if dv is None else dv
382
+
383
+ if not is_varlen:
384
+ softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
385
+ else:
386
+ softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)
387
+
388
+ return softmax_d
389
+
390
+
391
+ def setup_context(ctx, inputs, output):
392
+ q, k, v = inputs[:3]
393
+ out, softmax_lse, _, _ = output
394
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
395
+ ctx.softmax_scale = inputs[-11]
396
+ ctx.causal = inputs[-10]
397
+ ctx.window_size = [inputs[-9], inputs[-8]]
398
+ ctx.attention_chunk = inputs[-7]
399
+ ctx.softcap = inputs[-6]
400
+ ctx.sm_margin = inputs[-1]
401
+
402
+
403
+ def _backward(ctx, dout, *grads):
404
+ q, k, v, out, softmax_lse = ctx.saved_tensors
405
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
406
+ _flash_attn_backward(
407
  dout,
408
  q,
409
  k,
410
  v,
411
  out,
412
  softmax_lse,
413
+ None, None, # cu_seqlens_q, cu_seqlens_k,
414
+ None, None, # sequed_q, sequed_k,
415
+ None, None, # max_seqlen_q, max_seqlen_k,
416
  dq,
417
  dk,
418
  dv,
419
+ ctx.softmax_scale,
420
+ ctx.causal,
421
+ ctx.window_size[0],
422
+ ctx.window_size[1],
423
+ ctx.softcap,
424
+ False, # deterministic
425
+ ctx.sm_margin,
426
  )
427
+ return dq, dk, dv, *((None,) * 21)
428
+
429
+
430
+ _flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
431
+
432
 
433
 
434
  class FlashAttnQKVPackedFunc(torch.autograd.Function):
 
445
  deterministic=False,
446
  num_heads_q=None,
447
  sm_margin=0,
448
+ return_softmax=False,
449
  ):
450
  if softmax_scale is None:
451
  softmax_scale = qkv.shape[-1] ** (-0.5)
 
473
  q_descale, k_descale, v_descale,
474
  softmax_scale,
475
  causal=causal,
476
+ window_size_left=window_size[0],
477
+ window_size_right=window_size[1],
478
  attention_chunk=attention_chunk,
479
  softcap=softcap,
480
  sm_margin=sm_margin,
 
489
  ctx.deterministic = deterministic
490
  ctx.ndim = qkv.dim()
491
  ctx.sm_margin = sm_margin
492
+ return (out, softmax_lse) if return_softmax else out
 
493
 
494
  @staticmethod
495
  def backward(ctx, dout, *args):
 
520
  dv,
521
  ctx.softmax_scale,
522
  ctx.causal,
523
+ ctx.window_size[0],
524
+ ctx.window_size[1],
525
  ctx.softcap,
526
  ctx.deterministic,
527
  ctx.sm_margin,
528
  )
529
  dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
530
+ return dqkv, None, None, None, None, None, None, None, None, None, None, None, None
531
 
532
 
533
  class FlashAttnFunc(torch.autograd.Function):
 
549
  pack_gqa=None,
550
  deterministic=False,
551
  sm_margin=0,
552
+ return_softmax=False,
553
  ):
554
  if softmax_scale is None:
555
  softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
 
569
  q_descale, k_descale, v_descale,
570
  softmax_scale,
571
  causal=causal,
572
+ window_size_left=window_size[0],
573
+ window_size_right=window_size[1],
574
  attention_chunk=attention_chunk,
575
  softcap=softcap,
576
  num_splits=num_splits,
 
586
  ctx.softcap = softcap
587
  ctx.deterministic = deterministic
588
  ctx.sm_margin = sm_margin
589
+ return (out, softmax_lse) if return_softmax else out
590
 
591
  @staticmethod
592
  def backward(ctx, dout, *args):
 
608
  dv,
609
  ctx.softmax_scale,
610
  ctx.causal,
611
+ ctx.window_size[0],
612
+ ctx.window_size[1],
613
  ctx.softcap,
614
  ctx.deterministic,
615
  ctx.sm_margin,
 
645
  pack_gqa=None,
646
  deterministic=False,
647
  sm_margin=0,
648
+ return_softmax=False,
649
  ):
650
  if softmax_scale is None:
651
  softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
 
669
  q_descale, k_descale, v_descale,
670
  softmax_scale,
671
  causal=causal,
672
+ window_size_left=window_size[0],
673
+ window_size_right=window_size[1],
674
  attention_chunk=attention_chunk,
675
  softcap=softcap,
676
  num_splits=num_splits,
 
688
  ctx.softcap = softcap
689
  ctx.deterministic = deterministic
690
  ctx.sm_margin = sm_margin
691
+ return (out, softmax_lse) if return_softmax else out
692
 
693
  @staticmethod
694
  def backward(ctx, dout, *args):
 
713
  dv,
714
  ctx.softmax_scale,
715
  ctx.causal,
716
+ ctx.window_size[0],
717
+ ctx.window_size[1],
718
  ctx.softcap,
719
  ctx.deterministic,
720
  ctx.sm_margin,
 
736
  deterministic=False,
737
  num_heads_q=None,
738
  sm_margin=0,
739
+ return_attn_probs=False,
740
  ):
741
  """dropout_p should be set to 0.0 during evaluation
742
  If Q, K, V are already stacked into 1 tensor, this function will be faster than
 
783
  deterministic,
784
  num_heads_q,
785
  sm_margin,
786
+ return_attn_probs,
787
  )
788
 
789
 
 
802
  pack_gqa=None,
803
  deterministic=False,
804
  sm_margin=0,
805
+ return_attn_probs=False,
806
  ):
807
  """dropout_p should be set to 0.0 during evaluation
808
  Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
 
864
  pack_gqa,
865
  deterministic,
866
  sm_margin,
867
+ return_attn_probs,
868
  )
869
 
870
 
 
889
  pack_gqa=None,
890
  deterministic=False,
891
  sm_margin=0,
892
+ return_attn_probs=False,
893
  ):
894
  return FlashAttnVarlenFunc.apply(
895
  q,
 
912
  pack_gqa,
913
  deterministic,
914
  sm_margin,
915
+ return_attn_probs,
916
  )
917
 
918
 
 
998
  q: (batch_size, seqlen, nheads, headdim)
999
  k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
1000
  or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
1001
+ page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.).
1002
  v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
1003
  or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
1004
  k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
 
1043
  softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
1044
  if cache_seqlens is not None and isinstance(cache_seqlens, int):
1045
  cache_seqlens = torch.full(
1046
+ (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
1047
  )
1048
  cache_seqlens = maybe_contiguous(cache_seqlens)
1049
  out, softmax_lse, *rest = _flash_attn_forward(
 
1070
  q_descale, k_descale, v_descale,
1071
  softmax_scale,
1072
  causal=causal,
1073
+ window_size_left=window_size[0],
1074
+ window_size_right=window_size[1],
1075
  attention_chunk=attention_chunk,
1076
  softcap=softcap,
1077
  rotary_interleaved=rotary_interleaved,
build/torch29-cxx11-cu130-x86_64-linux/metadata.json ADDED
@@ -0,0 +1 @@
+ {"python-depends":[]}
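Taken together, the changes register `_flash_attn_forward` and `_flash_attn_backward` through `torch.library.custom_op` and pair them with `register_fake` implementations, so output shapes and dtypes can be inferred without launching the CUDA kernels; the intended payoff is traceability under `torch.compile` / `torch.export`. A minimal sketch of exercising that path, assuming the interface above is importable as `flash_attn_interface` (the import path, and compiling without graph breaks, are assumptions rather than something this commit demonstrates):

```python
import torch
from flash_attn_interface import flash_attn_func  # import path is an assumption

def attend(q, k, v):
    # Calls into the custom op registered above; during tracing, the
    # _flash_attn_forward_fake registration supplies output shapes/dtypes.
    return flash_attn_func(q, k, v, causal=True)

compiled_attend = torch.compile(attend)

q = torch.randn(2, 256, 8, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
k = torch.randn(2, 256, 8, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
v = torch.randn(2, 256, 8, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)

out = compiled_attend(q, k, v)
out.sum().backward()  # backward routes through the registered _flash_attn_backward op
```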