Skip to content

Commit c9450a3

Browse files
authored
add fp16 support for flash_varlen kernel (#3416)
1 parent 950e509 commit c9450a3

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

csrc/cpu/aten/kernels/PagedAttentionKrnl.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,6 +1280,57 @@ void flash_attn_varlen_cpu_kernel_impl(
12801280
v_scale);
12811281
}
12821282

1283+
} else if (query.scalar_type() == at::ScalarType::Half) {
1284+
if (max_seqlen_q >= 768) {
1285+
flash_attn_varlen_kernel<at::Half, at::Half, 128>(
1286+
out,
1287+
query,
1288+
key,
1289+
value,
1290+
cu_seqlens_q,
1291+
cu_seqlens_kv,
1292+
max_seqlen_q,
1293+
max_seqlen_kv,
1294+
softmax_scale,
1295+
is_causal,
1296+
block_table,
1297+
alibi_slopes,
1298+
k_scale,
1299+
v_scale);
1300+
} else if (max_seqlen_q >= 192) {
1301+
flash_attn_varlen_kernel<at::Half, at::Half, 64>(
1302+
out,
1303+
query,
1304+
key,
1305+
value,
1306+
cu_seqlens_q,
1307+
cu_seqlens_kv,
1308+
max_seqlen_q,
1309+
max_seqlen_kv,
1310+
softmax_scale,
1311+
is_causal,
1312+
block_table,
1313+
alibi_slopes,
1314+
k_scale,
1315+
v_scale);
1316+
} else {
1317+
flash_attn_varlen_kernel<at::Half, at::Half, 32>(
1318+
out,
1319+
query,
1320+
key,
1321+
value,
1322+
cu_seqlens_q,
1323+
cu_seqlens_kv,
1324+
max_seqlen_q,
1325+
max_seqlen_kv,
1326+
softmax_scale,
1327+
is_causal,
1328+
block_table,
1329+
alibi_slopes,
1330+
k_scale,
1331+
v_scale);
1332+
}
1333+
12831334
} else {
12841335
TORCH_CHECK(false, "Unsupported data type for ipex::flash_attn_varlen");
12851336
}

tests/cpu/test_flash_attention_varlen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
NUM_HEADS = [32]
99
NUM_QUERIES_PER_KV = [1, 4]
1010
HEAD_SIZES = [128, 64]
11-
DTYPES = [torch.bfloat16, torch.float32]
11+
DTYPES = [torch.bfloat16, torch.float32, torch.float16]
1212

1313

1414
def mha_ref(q, k, v, scale, is_causal):

0 commit comments

Comments
 (0)