@@ -207,55 +207,55 @@ This is exactly what happens in the widespread case of [FlashAttention version 2
207207The FlashAttention v2 forward pass algorithm in pseudo-code is:
208208
209209``` python
210- # Inputs: Q, K and V are 2D matrices in Global Memory
211- def FlashAttention2_forward(Q, K, V):
212-     O = torch.zeros_like(Q, requires_grad=True)
213-     L = torch.zeros(Q.shape[:-1])[..., None]
214-
215-     Q_BLOCKS = torch.split(Q, BLOCK_SHAPE)
216-     K_BLOCKS = torch.split(K, BLOCK_SHAPE)
217-     V_BLOCKS = torch.split(V, BLOCK_SHAPE)
218-
219-     Tr = len(Q_BLOCKS)
220-     Tc = len(K_BLOCKS)
221-
222-     for i in range(Tr):
223-         Qi = load(Q_BLOCKS[i])         # Load data from Global Memory to SRAM
224-         Oi = torch.zeros(BLOCK_SHAPE)  # No load required, initialized on chip
225-         li = torch.zeros(BLOCK_SHAPE)  # No load required, initialized on chip
226-         mi = NEG_INF                   # No load required, initialized on chip
227-
228-         for j in range(Tc):
229-             Kj = load(K_BLOCKS[j])     # Load data from Global Memory to SRAM
230-             Vj = load(V_BLOCKS[j])     # Load data from Global Memory to SRAM
231-
232-             KTj = Kj.transpose()
233-             S_ij = matmul(Qi, KTj)
234-
235-             P_ij, m_block_ij, mi_new, li_new = online_softmax(S_ij, mi, li)
236-
237-             P_ij_Vj = matmul(P_ij, Vj)
238-             Oi = torch.exp(mi - mi_new) * Oi + torch.exp(m_block_ij - mi_new) * P_ij_Vj
239-
240-             # update li and mi
241-             li = li_new
242-             mi = mi_new
243-
244-         Oi = Oi / diag(li)
245-         O.store(Oi, i)                 # Store data to Global Memory as the i-th block of O
246-         L.store(li, i)                 # Store data to Global Memory as the i-th block of L
247-
248-     return O, L
210+  1 # Inputs: Q, K and V are 2D matrices in Global Memory
211+  2 def FlashAttention2_forward(Q, K, V):
212+  3     O = torch.zeros_like(Q, requires_grad=True)
213+  4     L = torch.zeros(Q.shape[:-1])[..., None]
214+  5
215+  6     Q_BLOCKS = torch.split(Q, BLOCK_SHAPE)
216+  7     K_BLOCKS = torch.split(K, BLOCK_SHAPE)
217+  8     V_BLOCKS = torch.split(V, BLOCK_SHAPE)
218+  9
219+ 10     Tr = len(Q_BLOCKS)
220+ 11     Tc = len(K_BLOCKS)
221+ 12
222+ 13     for i in range(Tr):
223+ 14         Qi = load(Q_BLOCKS[i])         # Load data from Global Memory to SRAM
224+ 15         Oi = torch.zeros(BLOCK_SHAPE)  # No load required, initialized on chip
225+ 16         li = torch.zeros(BLOCK_SHAPE)  # No load required, initialized on chip
226+ 17         mi = NEG_INF                   # No load required, initialized on chip
227+ 18
228+ 19         for j in range(Tc):
229+ 20             Kj = load(K_BLOCKS[j])     # Load data from Global Memory to SRAM
230+ 21             Vj = load(V_BLOCKS[j])     # Load data from Global Memory to SRAM
231+ 22
232+ 23             KTj = Kj.transpose()
233+ 24             S_ij = matmul(Qi, KTj)
234+ 25
235+ 26             P_ij, m_block_ij, mi_new, li_new = online_softmax(S_ij, mi, li)
236+ 27
237+ 28             P_ij_Vj = matmul(P_ij, Vj)
238+ 29             Oi = torch.exp(mi - mi_new) * Oi + torch.exp(m_block_ij - mi_new) * P_ij_Vj
239+ 30
240+ 31             # update li and mi
241+ 32             li = li_new
242+ 33             mi = mi_new
243+ 34
244+ 35         Oi = Oi / diag(li)
245+ 36         O.store(Oi, i)                 # Store data to Global Memory as the i-th block of O
246+ 37         L.store(li, i)                 # Store data to Global Memory as the i-th block of L
247+ 38
248+ 39     return O, L
249249```
250250
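The `online_softmax` helper called on line 26 is not defined in the pseudo-code. A minimal sketch of what it could compute, consistent with the rescaling applied to the output block on line 29 (the tensor shapes and the exact set of return values are assumptions made for illustration, not the original code), is:

``` python
import torch

def online_softmax(S_ij, mi, li):
    # S_ij: (Br, Bc) block of attention scores.
    # mi, li: running row-wise maximum and softmax denominator, shape (Br, 1).
    m_block_ij = S_ij.max(dim=-1, keepdim=True).values   # row-wise max of this block
    mi_new = torch.maximum(mi, m_block_ij)               # updated running maximum
    # Unnormalized probabilities referenced to the block maximum; the caller
    # rescales them by exp(m_block_ij - mi_new) when accumulating the output.
    P_ij = torch.exp(S_ij - m_block_ij)
    # Rescale the old denominator and add this block's contribution.
    li_new = torch.exp(mi - mi_new) * li + torch.exp(m_block_ij - mi_new) * P_ij.sum(dim=-1, keepdim=True)
    return P_ij, m_block_ij, mi_new, li_new
```

Referencing each block to its own maximum keeps every exponent non-positive, which is what makes the running update numerically safe.
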
251251In the second version of FlashAttention, the loop order has been reversed (relative to v1) to promote
252252data locality.
253253As long as there is enough local memory (or registers) to hold all the needed data, this algorithm works well and
254254provides significant performance improvements over FlashAttention v1 (in the paper, the authors report about 2×
255255faster for the CUTLASS implementation and 1.3-1.5× faster for the Triton implementation on an Nvidia Ampere A100 GPU).
256- Deployed on a GPU target, lines 4-10 constitute the computing kernel that is dispatched to a Thread Block/Work-Group (
256+ Deployed on a GPU target, lines 13-37 constitute the computing kernel that is dispatched to a Thread Block/Work-Group (
257257i.e. an SM/XeCore).
258- But as you can see, variable Q is loaded before the loop (line 4) and remains *live* across the loop.
258+ As you can see, variable Q is loaded before the loop (line 14) and remains *live* across the loop.
259259
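This blocked formulation can be sanity-checked against the reference `softmax(Q @ K.T) @ V` with plain PyTorch on small tensors. The sketch below (toy sizes, an arbitrary block size of 32, no 1/sqrt(d) scaling, and illustrative names) follows the v2 loop order: the outer loop walks over Q blocks and each output block is written back only once.

``` python
import torch

def blocked_attention_forward(Q, K, V, block=32):
    """Blocked forward pass with online softmax (outer loop over Q blocks, as in v2)."""
    O = torch.zeros_like(Q)
    Q_blocks = torch.split(Q, block)
    K_blocks = torch.split(K, block)
    V_blocks = torch.split(V, block)

    for i, Qi in enumerate(Q_blocks):                    # one Q block per kernel instance
        Oi = torch.zeros_like(Qi)                        # accumulators live on chip in the real kernel
        li = torch.zeros(Qi.shape[0], 1)
        mi = torch.full((Qi.shape[0], 1), float("-inf"))

        for Kj, Vj in zip(K_blocks, V_blocks):           # stream over K/V blocks
            S_ij = Qi @ Kj.T
            m_block = S_ij.max(dim=-1, keepdim=True).values
            mi_new = torch.maximum(mi, m_block)
            P_ij = torch.exp(S_ij - m_block)
            li = torch.exp(mi - mi_new) * li + torch.exp(m_block - mi_new) * P_ij.sum(dim=-1, keepdim=True)
            Oi = torch.exp(mi - mi_new) * Oi + torch.exp(m_block - mi_new) * (P_ij @ Vj)
            mi = mi_new

        O[i * block:(i + 1) * block] = Oi / li           # single write-back per Q block
    return O

Q, K, V = (torch.randn(128, 64) for _ in range(3))
reference = torch.softmax(Q @ K.T, dim=-1) @ V
assert torch.allclose(blocked_attention_forward(Q, K, V), reference, atol=1e-5)
```

In the real kernel, Qi, Oi, li and mi stay in registers/SRAM for the entire inner loop, which is where the long live range of Q comes from.
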
260260The long lifespan of variable Q is even more problematic in the causal variation of the FlashAttention implementation.
261261The causal variation is defined in the paper as:
@@ -264,7 +264,7 @@ The causal variation is defined in the paper as :
264264
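Conceptually, the causal mask keeps only those scores whose key position does not exceed the query position. A minimal block-level sketch of such a mask (with `q_offset` and `k_offset` as illustrative names for the absolute positions of the block's first row and column) could be:

``` python
import torch

def apply_causal_mask(S_ij, q_offset, k_offset):
    # Absolute query (row) and key (column) positions covered by this block.
    q_idx = q_offset + torch.arange(S_ij.shape[0])[:, None]
    k_idx = k_offset + torch.arange(S_ij.shape[1])[None, :]
    # A query may only attend to keys at the same or an earlier position.
    return S_ij.masked_fill(k_idx > q_idx, float("-inf"))
```

Blocks that lie entirely above the diagonal produce an all-masked score block and can simply be skipped.
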
265265The Triton implementation of FlashAttention v2 with causal mask is as follows:
266266
267- ``` python {.line-numbers}
267+ ``` python
268268@triton.jit
269269def _attn_fwd(Q_block_ptr, K_block_ptr, V_block_ptr, sm_scale, M, N_CTX: tl.constexpr,  #
270270              BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,  #