Commit 28f896a

Merge pull request #8 from mfrancepillois/maxime/fixLineNumbers
Fix Line Number issue.
2 parents 93bcc4c + c557885 commit 28f896a

File tree

1 file changed: +42 -42 lines changed

_collections/_portal_posts/2025-09-02-improving-triton-flashattention-performance-on-intel-gpu.md

Lines changed: 42 additions & 42 deletions
@@ -207,55 +207,55 @@ This is exactly what happens in the widespread case of [FlashAttention version 2
 The FlashAttention v2 Forward pass algorithm in pseudo-code is:

 ```python
-# Inputs : Q, K and V are 2D Matrices in Global Memory
-def FlashAttention2_forward(Q, K, V):
-    O = torch.zeros_like(Q, requires_grad=True)
-    L = torch.zeros(Q.shape[:-1])[...,None]
-
-    Q_BLOCKS = torch.split(Q, BLOCK_SHAPE)
-    K_BLOCKS = torch.split(K, BLOCK_SHAPE)
-    V_BLOCKS = torch.split(V, BLOCK_SHAPE)
-
-    Tr = len(Q_BLOCKS)
-    Tc = len(K_BLOCKS)
-
-    for i in range(Tr):
-        Qi = load(Q_BLOCKS[i])        # Load data from Global Memory to SRAM
-        Oi = torch.zeros(BLOCK_SHAPE) # No load required, Initialized on chip
-        li = torch.zeros(BLOCK_SHAPE) # No load required, Initialized on chip
-        mi = NEG_INF                  # No load required, Initialized on chip
-
-        for j in range(Tc):
-            Kj = load(K_BLOCKS[j])    # Load data from Global Memory to SRAM
-            Vj = load(V_BLOCKS[j])    # Load data from Global Memory to SRAM
-
-            KTj = Kj.transpose()
-            S_ij = matmul(Qi, KTj)
-
-            P_ij, m_block_ij, mi_new, li_new = online_softmax(S_ij, mi, li)
-
-            P_ij_Vj = matmul(P_ij, Vj)
-            Oij = (li/li_new) * torch.exp(mi - mi_new) * Oi + (torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj
-
-            # update li and mi
-            li = li_new
-            mi = mi_new
-
-        Oi = Oij / diag(li)
-        O.store(Oi, i)                # Store data to Global Memory as the i-th block of O
-        L.store(li, i)                # Store data to Global Memory as the i-th block of L
-
-    return O, L
+1  # Inputs : Q, K and V are 2D Matrices in Global Memory
+2  def FlashAttention2_forward(Q, K, V):
+3      O = torch.zeros_like(Q, requires_grad=True)
+4      L = torch.zeros(Q.shape[:-1])[...,None]
+5
+6      Q_BLOCKS = torch.split(Q, BLOCK_SHAPE)
+7      K_BLOCKS = torch.split(K, BLOCK_SHAPE)
+8      V_BLOCKS = torch.split(V, BLOCK_SHAPE)
+9
+10     Tr = len(Q_BLOCKS)
+11     Tc = len(K_BLOCKS)
+12
+13     for i in range(Tr):
+14         Qi = load(Q_BLOCKS[i])        # Load data from Global Memory to SRAM
+15         Oi = torch.zeros(BLOCK_SHAPE) # No load required, Initialized on chip
+16         li = torch.zeros(BLOCK_SHAPE) # No load required, Initialized on chip
+17         mi = NEG_INF                  # No load required, Initialized on chip
+18
+19         for j in range(Tc):
+20             Kj = load(K_BLOCKS[j])    # Load data from Global Memory to SRAM
+21             Vj = load(V_BLOCKS[j])    # Load data from Global Memory to SRAM
+22
+23             KTj = Kj.transpose()
+24             S_ij = matmul(Qi, KTj)
+25
+26             P_ij, m_block_ij, mi_new, li_new = online_softmax(S_ij, mi, li)
+27
+28             P_ij_Vj = matmul(P_ij, Vj)
+29             Oij = (li/li_new) * torch.exp(mi - mi_new) * Oi + (torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj
+30
+31             # update li and mi
+32             li = li_new
+33             mi = mi_new
+34
+35         Oi = Oij / diag(li)
+36         O.store(Oi, i)                # Store data to Global Memory as the i-th block of O
+37         L.store(li, i)                # Store data to Global Memory as the i-th block of L
+38
+39     return O, L
 ```

 In the second version of the implementation of the FlashAttention model, the loop order has been reversed to promote
 data locality.
 As long as there is enough local memory (or registers) to contain all the needed data, this algorithm works fine and
 provides significant performance improvements compared to FlashAttention v1 (in the paper, the authors mention 2x faster
 for the Cutlass implementation and 1.3-1.5× faster in Triton on an Nvidia Ampere GPU A100).
-Deployed on a GPU target, line 4-10 constitutes the computing kernel that is dispatched to a Thread Block/Work-Group (
+Deployed on a GPU target, line 13-37 constitutes the computing kernel that is dispatched to a Thread Block/Work-Group (
 i.e. a SM/XeCore).
-But as you can see, variable Q is loaded before the loop (line 4) and remains *live* across the loop.
+As you can see, variable Q is loaded before the loop (line 14) and remains *live* across the loop.

 The long lifespan of variable Q is even more problematic in the causal variation of the FlashAttention implementation.
 The causal variation is defined in the paper as :
@@ -264,7 +264,7 @@ The causal variation is defined in the paper as :

 The Triton implementation of FlashAttention v2 with causal mask is as follow:

-```python {.line-numbers}
+```python
 @triton.jit
 def _attn_fwd(Q_block_ptr, K_block_ptr, V_block_ptr, sm_scale, M, N_CTX: tl.constexpr, #
               BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
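Note for readers of this diff: the numbered pseudocode above calls an `online_softmax` helper (used on line 26) whose body lies outside the changed hunks. A minimal sketch of what such a helper conventionally computes in FlashAttention v2, the block-local probabilities together with the updated running max and normalizer, is shown below. The `torch` calls, tensor shapes and handling of the initial scalar `NEG_INF` are assumptions made for illustration; this is not code taken from the patched post.

```python
import torch

def online_softmax(S_ij, mi, li):
    # Sketch of the running ("online") softmax update used by FlashAttention v2.
    # S_ij: score block (BLOCK_M x BLOCK_N); mi, li: running row max and
    # normalizer carried across the inner loop (assumed shape BLOCK_M x 1).
    mi = torch.as_tensor(mi, dtype=S_ij.dtype)            # accepts the scalar NEG_INF of the first iteration
    m_block_ij = S_ij.max(dim=-1, keepdim=True).values    # block-local row max
    P_ij = torch.exp(S_ij - m_block_ij)                   # unnormalized probabilities
    l_block_ij = P_ij.sum(dim=-1, keepdim=True)           # block-local normalizer

    mi_new = torch.maximum(mi, m_block_ij)                # updated running max
    # Rescale the previous and the block-local normalizers to the new max.
    li_new = torch.exp(mi - mi_new) * li + torch.exp(m_block_ij - mi_new) * l_block_ij
    return P_ij, m_block_ij, mi_new, li_new
```

Returning the unnormalized block `P_ij` together with `m_block_ij` is what allows line 29 of the numbered pseudocode to rescale the partial output `Oij` without recomputing the softmax.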
