
Commit 620a9bf

[flash attention] fix bugs for attention mask (#2987)
1 parent 52f8c48 commit 620a9bf

File tree

2 files changed: +287 -103 lines

csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp

Lines changed: 181 additions & 85 deletions
@@ -283,76 +283,52 @@ inline Vectorized<float> exp_u20(Vectorized<float> data) {
 #endif
 
 // out = val * a + b
-template <typename T1, typename T2>
+// is_b_stride_zero: If the stride of b is 0 (mask broadcasting case),
+// take b as a scalar pointer.
+template <bool is_b_stride_zero, typename T1, typename T2>
 inline void _scale_attn_mask_fusion_kernel(
     T1* a,
     T2* b,
     const int& size,
     T1* out,
     T1& val) {
-  auto vec_size = at::vec::Vectorized<T1>::size();
-  auto vec_scale = at::vec::Vectorized<T1>(val);
-  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
-    auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
-    auto tmp1 = at::vec::Vectorized<T2>::loadu(b + i);
-    auto tmp2 = at::vec::convert<T1>(tmp1);
-    auto tmp3 = tmp0 * vec_scale + tmp2;
-    _store(out + i, tmp3);
-  }
-  for (long i = vec_size * (size / vec_size); i < size; i++) {
-    auto tmp0 = a[i];
-    auto tmp1 = (T1)b[i];
-    out[i] = tmp0 * val + tmp1;
-  }
-}
-
-// out = val * a + b
-template <typename T1>
-inline void _scale_attn_mask_fusion_kernel(
-    T1* a,
-    T1* b,
-    const int& size,
-    T1* out,
-    T1& val) {
-  auto vec_size = at::vec::Vectorized<T1>::size();
-  auto vec_scale = at::vec::Vectorized<T1>(val);
-  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
-    auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
-    auto tmp1 = at::vec::Vectorized<T1>::loadu(b + i);
-    auto tmp2 = tmp0 * vec_scale + tmp1;
-    _store(out + i, tmp2);
-  }
-  for (long i = vec_size * (size / vec_size); i < size; i++) {
-    auto tmp0 = a[i];
-    auto tmp1 = b[i];
-    out[i] = tmp0 * val + tmp1;
-  }
-}
-
-// out = b ? val * a : -inf
-template <typename T1>
-inline void _scale_attn_mask_fusion_kernel(
-    T1* a,
-    bool* b,
-    const int& size,
-    T1* out,
-    T1& val) {
-  auto vec_size = at::vec::Vectorized<T1>::size();
-  auto vec_scale = at::vec::Vectorized<T1>(val);
-  auto neg_inf = -std::numeric_limits<T1>::infinity();
-  auto vec_neg_inf = at::vec::Vectorized<T1>(neg_inf);
-  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
-    auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
-    auto tmp1 = at::vec::Vectorized<bool>::loadu(b + i);
-    auto tmp2 = at::vec::convert<T1>(tmp1);
-    auto tmp3 =
-        at::vec::Vectorized<T1>::blendv(vec_neg_inf, tmp0 * vec_scale, tmp2);
-    _store(out + i, tmp3);
-  }
-  for (long i = vec_size * (size / vec_size); i < size; i++) {
-    auto tmp0 = a[i];
-    auto tmp1 = b[i];
-    out[i] = tmp1 ? tmp0 * val : neg_inf;
+  const auto vec_size1 = at::vec::Vectorized<T1>::size();
+  const auto vec_size2 = at::vec::Vectorized<T2>::size();
+  constexpr int64_t T1_n =
+      (vec_size2 == vec_size1 * 2 && is_reduced_floating_point_v<T2>) ? 2 : 1;
+  constexpr int64_t T2_n = 1;
+  auto vec_scale = at::vec::VectorizedN<T1, T1_n>(val);
+  int64_t i = 0;
+  if (is_b_stride_zero) {
+    auto b_first_val = (T1)b[0];
+    auto b_first_vec = at::vec::VectorizedN<T2, T2_n>(b_first_val);
+    for (; i < size - (size % vec_size2); i += vec_size2) {
+      auto a_n = at::vec::VectorizedN<T1, T1_n>::loadu(a + i);
+      auto b_n = b_first_vec;
+      at::vec::VectorizedN<T1, T1_n> b_n_convert =
+          at::vec::convert<T1, T1_n, T2, T2_n, true>(b_n);
+      auto res = a_n * vec_scale + b_n_convert;
+      res.store(out + i);
+    }
+    for (; i < size; i++) {
+      auto tmp0 = a[i];
+      auto tmp1 = b_first_val;
+      out[i] = tmp0 * val + tmp1;
+    }
+  } else {
+    for (; i < size - (size % vec_size2); i += vec_size2) {
+      auto a_n = at::vec::VectorizedN<T1, T1_n>::loadu(a + i);
+      auto b_n = at::vec::VectorizedN<T2, T2_n>::loadu(b + i);
+      at::vec::VectorizedN<T1, T1_n> b_n_convert =
+          at::vec::convert<T1, T1_n, T2, T2_n, true>(b_n);
+      auto res = a_n * vec_scale + b_n_convert;
+      res.store(out + i);
+    }
+    for (; i < size; i++) {
+      auto tmp0 = a[i];
+      auto tmp1 = (T1)b[i];
+      out[i] = tmp0 * val + tmp1;
+    }
   }
 }
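Editor's note: the three old overloads (generic mask, same-dtype mask, bool mask) are folded into one kernel that always computes out = a * val + b; bool masks are no longer special-cased here because they are converted to additive masks up front (see convert_boolean_attn_mask further down). The new is_b_stride_zero flag handles the broadcast case where a single mask value applies to the whole row. Below is a minimal scalar sketch of the same semantics, not part of the commit; the helper name scale_mask_ref and the test values are made up for illustration.

#include <cstdio>
#include <vector>

// Scalar reference for the fused "scale + add mask" step (illustrative only).
template <bool is_b_stride_zero, typename T1, typename T2>
void scale_mask_ref(const T1* a, const T2* b, int size, T1* out, T1 val) {
  for (int i = 0; i < size; ++i) {
    // Broadcast case: the whole row reads the single value b[0].
    T1 m = is_b_stride_zero ? static_cast<T1>(b[0]) : static_cast<T1>(b[i]);
    out[i] = a[i] * val + m;  // out = val * a + b
  }
}

int main() {
  std::vector<float> qk{1.f, 2.f, 3.f, 4.f};
  std::vector<float> mask{-1.f};  // one value broadcast over the row
  std::vector<float> out(4);
  scale_mask_ref</*is_b_stride_zero=*/true>(
      qk.data(), mask.data(), 4, out.data(), 0.5f);
  for (float v : out) std::printf("%.2f ", v);  // -0.50 0.00 0.50 1.00
  std::printf("\n");
  return 0;
}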

@@ -425,6 +401,82 @@ inline void _mul_reduce_max_fusion_kernel(
           vec_tmp_max));
 }
 
+// This function is used to produce an attn_mask in a standard format
+inline std::optional<at::Tensor> convert_boolean_attn_mask(
+    const std::optional<at::Tensor>& attn_mask,
+    caffe2::TypeMeta dtype) {
+  // Pass through
+  if (!attn_mask.has_value()) {
+    return c10::nullopt;
+  }
+  // Convert boolean mask to additive mask
+  if (attn_mask->dtype() == at::kBool) {
+    auto new_attn_mask = at::zeros_like(attn_mask.value(), dtype);
+    new_attn_mask.masked_fill_(
+        attn_mask->logical_not(), -std::numeric_limits<double>::infinity());
+    return new_attn_mask;
+  }
+  // Otherwise, attn_mask represents an additive attention tensor
+  return attn_mask;
+}
+
+// Support mask shapes:
+// 2d: ({Q_seq_len, 1} x {KV_seq_len, 1})
+// 4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})
+inline bool check_attn_mask_shape(
+    at::Tensor& attn_mask,
+    int64_t batchSize,
+    int64_t num_head,
+    int64_t qSize,
+    int64_t kvSize) {
+  if (attn_mask.size(-2) != qSize && attn_mask.size(-2) != 1) {
+    return false;
+  }
+  if (attn_mask.size(-1) != kvSize && attn_mask.size(-1) != 1) {
+    return false;
+  }
+  if (attn_mask.dim() == 2) {
+    return true;
+  } else if (attn_mask.dim() == 4) {
+    if ((attn_mask.size(0) == 1 || attn_mask.size(0) == batchSize) &&
+        (attn_mask.size(1) == 1 || attn_mask.size(1) == num_head)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Reshape attention mask to 4d
+inline void reshape_attn_mask_to_4d(
+    at::Tensor& attn_mask,
+    int64_t batchSize,
+    int64_t num_head,
+    int64_t qSize,
+    int64_t kvSize) {
+  TORCH_CHECK(
+      check_attn_mask_shape(attn_mask, batchSize, num_head, qSize, kvSize),
+      "IPEX flash_attention: Please use the following attn mask shapes: ",
+      "2d - ({Q_seq_len, 1} x {KV_seq_len, 1}); ",
+      "4d - ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})");
+  int64_t attn_mask_size_0 = 1;
+  int64_t attn_mask_size_1 = 1;
+  if (attn_mask.dim() == 4) {
+    if (attn_mask.size(0) == batchSize) {
+      attn_mask_size_0 = batchSize;
+    }
+    if (attn_mask.size(1) == num_head) {
+      attn_mask_size_1 = num_head;
+    }
+  }
+  attn_mask = attn_mask
+                  .view(
+                      {attn_mask_size_0,
+                       attn_mask_size_1,
+                       attn_mask.size(-2),
+                       attn_mask.size(-1)})
+                  .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize});
+}
+
 /*
 *Caculate the flash attention SDPA.
 *@template scalar_t: q/k/v data type
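Editor's note: a minimal sketch, not from the commit, of what the two new helpers do when driven through the public ATen API; shapes and values are illustrative. convert_boolean_attn_mask turns a bool mask into an additive float mask (false becomes -inf), and reshape_attn_mask_to_4d views and expands the mask so that broadcast dimensions end up with stride 0, which the stride logic below keys on.

#include <ATen/ATen.h>
#include <iostream>
#include <limits>

int main() {
  int64_t qSize = 3, kvSize = 4;
  // Bool mask over the key dimension only: true = keep, false = drop.
  auto bool_mask = at::rand({1, kvSize}) > 0.5;
  // convert_boolean_attn_mask-style conversion: bool -> additive float mask.
  auto additive = at::zeros_like(bool_mask, at::kFloat);
  additive.masked_fill_(
      bool_mask.logical_not(), -std::numeric_limits<double>::infinity());
  // reshape_attn_mask_to_4d-style broadcast: view keeps broadcast dims at 1,
  // expand stretches them without copying, so those dims get stride 0.
  auto mask4d =
      additive.view({1, 1, 1, kvSize}).expand({1, 1, qSize, kvSize});
  std::cout << mask4d.sizes() << " strides " << mask4d.strides() << "\n";
  // expected: [1, 1, 3, 4] strides [4, 4, 0, 1] (query dim is broadcast)
  return 0;
}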
@@ -480,6 +532,12 @@ cpu_flash_attention(
   int64_t num_head = query.size(2);
   int64_t headSize = query.size(3);
 
+  // reshape mask
+  if (attention_mask.has_value()) {
+    reshape_attn_mask_to_4d(
+        attention_mask.value(), batchSize, num_head, qSize, kvSize);
+  }
+
   // Strides
   int64_t qStrideB = query.stride(0);
   int64_t qStrideM = query.stride(1);
@@ -505,7 +563,13 @@ cpu_flash_attention(
       ? attention_mask.value().stride(1)
       : 0;
   int64_t mStrideM =
-      attention_mask.has_value() ? attention_mask.value().stride(2) : 0;
+      (attention_mask.has_value() && attention_mask.value().size(2) > 1)
+      ? attention_mask.value().stride(2)
+      : 0;
+  int64_t mStrideN =
+      (attention_mask.has_value() && attention_mask.value().size(3) > 1)
+      ? attention_mask.value().stride(3)
+      : 0;
 
   int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
   int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
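Editor's note: mStrideN is the new quantity here. After reshape_attn_mask_to_4d, a mask that broadcasts along the key dimension carries stride 0 on dim 3, and the size(...) > 1 guards keep the stride at 0 for genuinely singleton dims as well; mStrideN == 0 is then what selects the is_b_stride_zero = true instantiation at the call sites below. A tiny illustration, not from the commit:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // A {1, 1, qSize, 1} mask expanded over kvSize, as reshape_attn_mask_to_4d
  // would produce for a per-query (key-broadcast) mask.
  auto mask = at::zeros({1, 1, 4, 1}).expand({1, 1, 4, 8});
  int64_t mStrideM = mask.size(2) > 1 ? mask.stride(2) : 0;  // 1
  int64_t mStrideN = mask.size(3) > 1 ? mask.stride(3) : 0;  // 0 -> broadcast
  std::cout << "mStrideM=" << mStrideM << " mStrideN=" << mStrideN << "\n";
  return 0;
}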
@@ -596,15 +660,24 @@ cpu_flash_attention(
         // And apply scaling factor
         if (attention_mask.has_value()) {
           for (int64_t row = 0; row < qBlockSize; ++row) {
-            // qk <- attn_mask ? qk : -inf, if attn_mask is bool
-            // qk <- qk + attn_mask, else
-            _scale_attn_mask_fusion_kernel(
-                qk_data + row * kvBlockSize,
-                mask_data + i * mStrideB + j * mStrideH +
-                    (m + row) * mStrideM + n,
-                kvBlockSize,
-                qk_data + row * kvBlockSize,
-                scaling_factor);
+            // qk <- qk * scaling_factor + attn_mask, else
+            if (mStrideN == 0) {
+              _scale_attn_mask_fusion_kernel</*is_stride_zero*/ true>(
+                  qk_data + row * kvBlockSize,
+                  mask_data + i * mStrideB + j * mStrideH +
+                      (m + row) * mStrideM,
+                  kvBlockSize,
+                  qk_data + row * kvBlockSize,
+                  scaling_factor);
+            } else {
+              _scale_attn_mask_fusion_kernel</*is_stride_zero*/ false>(
+                  qk_data + row * kvBlockSize,
+                  mask_data + i * mStrideB + j * mStrideH +
+                      (m + row) * mStrideM + n,
+                  kvBlockSize,
+                  qk_data + row * kvBlockSize,
+                  scaling_factor);
+            }
           }
         }
         // Update coefficients with Softmax
@@ -737,6 +810,12 @@ cpu_flash_attention(
   int64_t num_head = query.size(2);
   int64_t headSize = query.size(3);
 
+  // reshape mask
+  if (attention_mask.has_value()) {
+    reshape_attn_mask_to_4d(
+        attention_mask.value(), batchSize, num_head, qSize, kvSize);
+  }
+
   // Strides
   int64_t qStrideB = query.stride(0);
   int64_t qStrideM = query.stride(1);
@@ -762,7 +841,13 @@ cpu_flash_attention(
       ? attention_mask.value().stride(1)
       : 0;
   int64_t mStrideM =
-      attention_mask.has_value() ? attention_mask.value().stride(2) : 0;
+      (attention_mask.has_value() && attention_mask.value().size(2) > 1)
+      ? attention_mask.value().stride(2)
+      : 0;
+  int64_t mStrideN =
+      (attention_mask.has_value() && attention_mask.value().size(3) > 1)
+      ? attention_mask.value().stride(3)
+      : 0;
 
   int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
   int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
@@ -1241,15 +1326,24 @@ cpu_flash_attention(
         // And apply scaling factor
         if (attention_mask.has_value()) {
           for (int64_t row = 0; row < qBlockSize; ++row) {
-            // qk <- attn_mask ? qk : -inf, if attn_mask is bool
-            // qk <- qk + attn_mask, else
-            _scale_attn_mask_fusion_kernel(
-                qk_data + row * kvBlockSize,
-                mask_data + i * mStrideB + j * mStrideH +
-                    (m + row) * mStrideM + n,
-                kvBlockSize,
-                qk_data + row * kvBlockSize,
-                scaling_factor);
+            // qk <- qk * scaling_factor + attn_mask, else
+            if (mStrideN == 0) {
+              _scale_attn_mask_fusion_kernel</*is_stride_zero*/ true>(
+                  qk_data + row * kvBlockSize,
+                  mask_data + i * mStrideB + j * mStrideH +
+                      (m + row) * mStrideM,
+                  kvBlockSize,
+                  qk_data + row * kvBlockSize,
+                  scaling_factor);
+            } else {
+              _scale_attn_mask_fusion_kernel</*is_stride_zero*/ false>(
+                  qk_data + row * kvBlockSize,
+                  mask_data + i * mStrideB + j * mStrideH +
+                      (m + row) * mStrideM + n,
+                  kvBlockSize,
+                  qk_data + row * kvBlockSize,
+                  scaling_factor);
+            }
           }
         }
         // Update coefficients with Softmax
@@ -1558,6 +1652,8 @@ std::tuple<at::Tensor, at::Tensor> flash_attention_kernel(
           attention_mask.value().stride(-1) == 1),
       "IPEX flash_attention: Q/K/V/Mask should be continuous on the last dim");
 
+  std::optional<at::Tensor> attn_mask =
+      convert_boolean_attn_mask(attention_mask, query.dtype());
   at::Tensor output =
       at::empty({batchSize, qSize, num_head, headSize}, query.options());
   const auto accumulate_dtype = at::toOpMathType(dtype);
@@ -1572,7 +1668,7 @@ std::tuple<at::Tensor, at::Tensor> flash_attention_kernel(
       value,
       dropout_p,
       is_causal,
-      attention_mask,
+      attn_mask,
       scale);
 
   output = output.transpose(1, 2);
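Editor's note: at the kernel entry point the mask is now normalized once via convert_boolean_attn_mask and the converted attn_mask is what gets threaded into cpu_flash_attention, so the inner loops only ever see an additive mask. A hypothetical end-to-end check, not part of the commit (whether the flash-attention path is actually taken depends on the build and the inputs):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Batch 1, 8 heads, 16 tokens, head dim 64 (illustrative sizes).
  auto q = at::randn({1, 8, 16, 64});
  auto k = at::randn({1, 8, 16, 64});
  auto v = at::randn({1, 8, 16, 64});
  // Boolean mask broadcast over batch and heads; false positions are dropped.
  auto mask = at::rand({16, 16}) > 0.1;
  auto out = at::scaled_dot_product_attention(q, k, v, mask);
  std::cout << out.sizes() << "\n";  // [1, 8, 16, 64]
  return 0;
}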
