
Commit c12230b

Refine paged_attn with compile (#3641)
1 parent 44eb98a


7 files changed: 157 additions, 73 deletions

csrc/cpu/aten/PagedAttention.cpp

Lines changed: 10 additions & 7 deletions
@@ -13,7 +13,7 @@ IPEX_DEFINE_DISPATCH(flash_attn_var_len_kernel_stub);
 /*
  *Caculate the masked multihead attention for decoder layer in decoder only
  */
-void single_query_cached_kv_attention_forward_cpu(
+at::Tensor single_query_cached_kv_attention_forward_cpu(
     at::Tensor& out, // [num_seqs, num_heads, head_size]
     at::Tensor& query, // [num_seqs, num_heads, head_size]
     at::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
@@ -29,7 +29,7 @@ void single_query_cached_kv_attention_forward_cpu(
     const double k_scale,
     const double v_scale,
     const double softcap) {
-  return single_query_cached_kv_attention_kernel_stub(
+  single_query_cached_kv_attention_kernel_stub(
       kCPU,
       out,
       query,
@@ -46,9 +46,10 @@ void single_query_cached_kv_attention_forward_cpu(
       k_scale,
       v_scale,
       softcap);
+  return out;
 }

-void reshape_and_cache_cpu(
+std::tuple<at::Tensor, at::Tensor> reshape_and_cache_cpu(
     at::Tensor& key,
     at::Tensor& value,
     at::Tensor& key_cache,
@@ -57,7 +58,7 @@ void reshape_and_cache_cpu(
     const std::string& kv_cache_dtype,
     const double k_scale,
     const double v_scale) {
-  return reshape_and_cache_kernel_stub(
+  reshape_and_cache_kernel_stub(
       kCPU,
       key,
       value,
@@ -67,9 +68,10 @@ void reshape_and_cache_cpu(
       kv_cache_dtype,
       k_scale,
       v_scale);
+  return std::make_tuple(key_cache, value_cache);
 }

-void flash_attn_varlen_cpu(
+at::Tensor flash_attn_varlen_cpu(
     at::Tensor& out,
     at::Tensor& query,
     at::Tensor& key,
@@ -84,11 +86,11 @@ void flash_attn_varlen_cpu(
     const c10::optional<at::Tensor>& alibi_slopes,
     int64_t window_size_left,
     int64_t window_size_right,
-    const std::string& kv_cache_dtype,
+    const std::string_view& kv_cache_dtype,
     const double k_scale,
     const double v_scale,
     const double softcap) {
-  return flash_attn_var_len_kernel_stub(
+  flash_attn_var_len_kernel_stub(
       kCPU,
       out,
       query,
@@ -108,6 +110,7 @@ void flash_attn_varlen_cpu(
       k_scale,
       v_scale,
       softcap);
+  return out;
 }

 } // namespace cpu
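
Note: the wrappers above still write into caller-provided buffers; the change is that they now also return them (`return out;`, `return std::make_tuple(key_cache, value_cache);`), so each op produces a value that a captured graph can carry. A minimal Python sketch of the same calling convention, using an invented `demo::scaled_copy` op (not an IPEX op) registered through torch.library:

import torch

# Sketch only: the "demo" namespace, op name, and schema are assumptions made
# for illustration; IPEX registers its real ops in C++ under torch_ipex::*.
lib = torch.library.Library("demo", "DEF")
lib.define("scaled_copy(Tensor(a!) out, Tensor src, float alpha) -> Tensor(a!)")

def scaled_copy_cpu(out, src, alpha):
    out.copy_(src * alpha)  # kernel mutates the pre-allocated output buffer...
    return out              # ...and returns it, mirroring `return out;` above

def scaled_copy_meta(out, src, alpha):
    return out              # meta kernel: shape/dtype only, no computation

lib.impl("scaled_copy", scaled_copy_cpu, "CPU")
lib.impl("scaled_copy", scaled_copy_meta, "Meta")

out = torch.empty(4)
result = torch.ops.demo.scaled_copy(out, torch.ones(4), 2.0)  # returns `out`

The Meta registration is what lets shape propagation proceed without running the CPU kernel; the `_meta_registrations.py` changes further down do the analogous thing for the real `torch_ipex` ops.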

csrc/cpu/aten/PagedAttention.h

Lines changed: 7 additions & 6 deletions
@@ -8,7 +8,7 @@ namespace cpu {

 namespace {

-void single_query_cached_kv_attention(
+at::Tensor single_query_cached_kv_attention_forward_cpu(
     at::Tensor& out, // [num_seqs, num_heads, head_size]
     at::Tensor& query, // [num_seqs, num_heads, head_size]
     at::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
@@ -24,9 +24,8 @@ void single_query_cached_kv_attention(
     const double k_scale,
     const double v_scale,
     const double softcap);
-}

-void reshape_and_cache(
+std::tuple<at::Tensor, at::Tensor> reshape_and_cache_cpu(
     at::Tensor& key,
     at::Tensor& value,
     at::Tensor& key_cache,
@@ -36,7 +35,7 @@ void reshape_and_cache(
     const double k_scale,
     const double v_scale);

-void flash_attn_varlen(
+at::Tensor flash_attn_varlen_cpu(
     at::Tensor& out,
     at::Tensor& query,
     at::Tensor& key,
@@ -51,11 +50,13 @@ void flash_attn_varlen(
     const c10::optional<at::Tensor>& alibi_slopes,
     int64_t window_size_left,
     int64_t window_size_right,
-    const std::string& kv_cache_dtype,
+    const std::string_view& kv_cache_dtype,
     const double k_scale,
     const double v_scale,
     const double softcap);

+} // namespace
+
 using single_query_cached_kv_attention_fn = void (*)(
     at::Tensor& out, // [num_seqs, num_heads, head_size]
     at::Tensor& query, // [num_seqs, num_heads, head_size]
@@ -98,7 +99,7 @@ using flash_attn_var_len_fn = void (*)(
     const c10::optional<at::Tensor>& alibi_slopes,
     int64_t window_size_left,
     int64_t window_size_right,
-    const std::string& kv_cache_dtype,
+    const std::string_view& kv_cache_dtype,
     const double k_scale,
     const double v_scale,
     const double softcap);

csrc/cpu/aten/kernels/PagedAttentionKrnl.cpp

Lines changed: 1 addition & 1 deletion
@@ -2016,7 +2016,7 @@ void flash_attn_varlen_cpu_kernel_impl(
     const c10::optional<at::Tensor>& alibi_slopes,
     int64_t window_size_left,
     int64_t window_size_right,
-    const std::string& kv_cache_dtype,
+    const std::string_view& kv_cache_dtype,
     const double k_scale,
     const double v_scale,
     const double softcap) {

intel_extension_for_pytorch/_meta_registrations.py

Lines changed: 28 additions & 3 deletions
@@ -132,9 +132,9 @@ def is_channels_last_3d(ten):

 @register_meta("reshape_and_cache")
 def meta_reshape_and_cache(
-    key, value, key_cache, value_cache, slot_mapping, k_scale, v_scale
+    key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, k_scale, v_scale
 ):
-    return None
+    return key_cache, value_cache


 @register_meta("single_query_cached_kv_attention")
@@ -153,8 +153,33 @@ def meta_single_query_cached_kv_attention(
     window_size,
     k_scale,
     v_scale,
+    softcap,
 ):
-    return None
+    return output
+
+
+@register_meta("flash_attn_varlen_func")
+def meta_flash_attn_varlen_func(
+    output,
+    query,
+    k_cache,
+    v_cache,
+    cu_seq_lens_q,
+    cu_seq_lens_kv,
+    max_seq_len_q,
+    max_seq_len_kv,
+    scale,
+    is_causal,
+    block_table,
+    alibi_slopes,
+    window_size_left,
+    window_size_right,
+    kv_cache_dtype,
+    k_scale,
+    v_scale,
+    softcap,
+):
+    return output


 @register_meta("convolution_forward")

intel_extension_for_pytorch/transformers/models/cpu/fusions/mha_fusion.py

Lines changed: 4 additions & 4 deletions
@@ -368,7 +368,7 @@ def reshape_and_cache(
         elif kv_cache_dtype != "auto":
             raise TypeError("unsupported kv_cache_dtype")

-        torch.ops.torch_ipex.reshape_and_cache(
+        return torch.ops.torch_ipex.reshape_and_cache(
             key,
             value,
             key_cache,
@@ -391,7 +391,7 @@ def reshape_and_cache_flash(
         k_scale=1.0,
         v_scale=1.0,
     ):
-        torch.ops.torch_ipex.reshape_and_cache(
+        return torch.ops.torch_ipex.reshape_and_cache(
             key,
             value,
             key_cache,
@@ -421,7 +421,7 @@ def single_query_cached_kv_attention(
         v_scale=1.0,
         softcap=-1.0,
     ):
-        torch.ops.torch_ipex.single_query_cached_kv_attention(
+        return torch.ops.torch_ipex.single_query_cached_kv_attention(
             output,
             query,
             key_cache,
@@ -469,7 +469,7 @@ def flash_attn_varlen_func(
             raise TypeError("only float8_e5m2 supported")
         elif kv_cache_dtype != "auto":
             raise TypeError("unsupported kv_cache_dtype")
-        torch.ops.torch_ipex.flash_attn_varlen_func(
+        return torch.ops.torch_ipex.flash_attn_varlen_func(
             output,
             query,
             k_cache,
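
Note: the only change to these fusion wrappers is that each now returns the value produced by the underlying `torch_ipex` op instead of discarding it, matching the new op signatures. A trivial sketch of the pattern, with a stock out-variant op (`torch.add`) standing in for the extension op:

import torch

def out_variant_wrapper(out, a, b):
    # Before: torch.add(a, b, out=out)   (called only for the side effect)
    # After: return the op's result so callers and compiled graphs see a value.
    return torch.add(a, b, out=out)

compiled = torch.compile(out_variant_wrapper)
buf = torch.empty(3)
print(compiled(buf, torch.ones(3), torch.ones(3)))  # tensor([2., 2., 2.])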

tests/cpu/test_flash_attention_varlen.py

Lines changed: 47 additions & 23 deletions
@@ -76,7 +76,7 @@ def mha_ref(q, k, v, scale, is_causal, window_size, softcap):


 class TestFlashAttnVarLen(TestCase):
-    @torch.inference_mode()
+    @torch.no_grad()
     def _test_flash_attn_varlen(
         self,
         num_heads: int,
@@ -86,6 +86,7 @@ def _test_flash_attn_varlen(
         is_causal: bool,
         dtype: torch.dtype,
         softcap: float,
+        is_compile: bool,
     ) -> None:
         random.seed(0)
         torch.manual_seed(0)
@@ -163,7 +164,11 @@ def _test_flash_attn_varlen(
             output_ref[cu_seq_lens_q[i] : cu_seq_lens_q[i + 1]] = output_i

         output = torch.empty_like(query)
-        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+        if is_compile:
+            f_c = torch.compile(ipex.llm.modules.PagedAttention.flash_attn_varlen_func)
+        else:
+            f_c = ipex.llm.modules.PagedAttention.flash_attn_varlen_func
+        f_c(
             output,
             query,
             k_cache,
@@ -185,7 +190,7 @@ def _test_flash_attn_varlen(
             output_ref, output, atol=1e-6 if dtype == torch.float else 5e-2
         )

-    @torch.inference_mode()
+    @torch.no_grad()
     def _test_flash_attn_varlen_fp8(
         self,
         num_heads: int,
@@ -195,6 +200,7 @@ def _test_flash_attn_varlen_fp8(
         is_causal: bool,
         dtype: torch.dtype,
         softcap: float,
+        is_compile: bool,
     ) -> None:
         random.seed(0)
         torch.manual_seed(0)
@@ -255,6 +261,7 @@ def _test_flash_attn_varlen_fp8(
         scale = float(1.0 / (head_size**0.5))

         output_ref = torch.empty_like(query)
+
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             output_ref,
             query,
@@ -272,9 +279,12 @@ def _test_flash_attn_varlen_fp8(
             window_size[1],
             softcap=softcap,
         )
-
+        if is_compile:
+            f_c = torch.compile(ipex.llm.modules.PagedAttention.flash_attn_varlen_func)
+        else:
+            f_c = ipex.llm.modules.PagedAttention.flash_attn_varlen_func
         output = torch.empty_like(query)
-        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+        f_c(
             output,
             query,
             k_cache.to(torch.float8_e5m2),
@@ -297,6 +307,9 @@ def _test_flash_attn_varlen_fp8(
         )

     def test_flash_attn_varlen(self):
+        COMPILE_TEST = (
+            1  # test torch.compile function for once, avoiding recompile in CI
+        )
         for (
             num_heads,
             num_queries_per_kv,
@@ -314,17 +327,24 @@ def test_flash_attn_varlen(self):
             DTYPES,
             SOFTCAP,
         ):
-            self._test_flash_attn_varlen(
-                num_heads,
-                num_queries_per_kv,
-                head_size,
-                window_size,
-                is_causal,
-                dtype,
-                softcap,
-            )
+            COMPILE = [True, False] if COMPILE_TEST == 1 else [False]
+            COMPILE_TEST = COMPILE_TEST - 1
+            for is_compile in COMPILE:
+                self._test_flash_attn_varlen(
+                    num_heads,
+                    num_queries_per_kv,
+                    head_size,
+                    window_size,
+                    is_causal,
+                    dtype,
+                    softcap,
+                    is_compile,
+                )

     def test_flash_attn_varlen_fp8(self):
+        COMPILE_TEST = (
+            1  # test torch.compile function for once, avoiding recompile in CI
+        )
         for (
             num_heads,
             num_queries_per_kv,
@@ -342,15 +362,19 @@ def test_flash_attn_varlen_fp8(self):
             [torch.float, torch.bfloat16],
             SOFTCAP,
         ):
-            self._test_flash_attn_varlen_fp8(
-                num_heads,
-                num_queries_per_kv,
-                head_size,
-                window_size,
-                is_causal,
-                dtype,
-                softcap,
-            )
+            COMPILE = [True, False] if COMPILE_TEST == 1 else [False]
+            COMPILE_TEST = COMPILE_TEST - 1
+            for is_compile in COMPILE:
+                self._test_flash_attn_varlen_fp8(
+                    num_heads,
+                    num_queries_per_kv,
+                    head_size,
+                    window_size,
+                    is_causal,
+                    dtype,
+                    softcap,
+                    is_compile,
+                )


 if __name__ == "__main__":
