@@ -283,76 +283,52 @@ inline Vectorized<float> exp_u20(Vectorized<float> data) {
 #endif
 
 // out = val * a + b
-template <typename T1, typename T2>
+// is_b_stride_zero: If the stride of b is 0 (mask broadcasting case),
+// take b as a scalar pointer.
+template <bool is_b_stride_zero, typename T1, typename T2>
 inline void _scale_attn_mask_fusion_kernel(
     T1* a,
     T2* b,
     const int& size,
     T1* out,
     T1& val) {
-  auto vec_size = at::vec::Vectorized<T1>::size();
-  auto vec_scale = at::vec::Vectorized<T1>(val);
-  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
-    auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
-    auto tmp1 = at::vec::Vectorized<T2>::loadu(b + i);
-    auto tmp2 = at::vec::convert<T1>(tmp1);
-    auto tmp3 = tmp0 * vec_scale + tmp2;
-    _store(out + i, tmp3);
-  }
-  for (long i = vec_size * (size / vec_size); i < size; i++) {
-    auto tmp0 = a[i];
-    auto tmp1 = (T1)b[i];
-    out[i] = tmp0 * val + tmp1;
-  }
-}
-
-// out = val * a + b
-template <typename T1>
-inline void _scale_attn_mask_fusion_kernel(
-    T1* a,
-    T1* b,
-    const int& size,
-    T1* out,
-    T1& val) {
-  auto vec_size = at::vec::Vectorized<T1>::size();
-  auto vec_scale = at::vec::Vectorized<T1>(val);
-  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
-    auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
-    auto tmp1 = at::vec::Vectorized<T1>::loadu(b + i);
-    auto tmp2 = tmp0 * vec_scale + tmp1;
-    _store(out + i, tmp2);
-  }
-  for (long i = vec_size * (size / vec_size); i < size; i++) {
-    auto tmp0 = a[i];
-    auto tmp1 = b[i];
-    out[i] = tmp0 * val + tmp1;
-  }
-}
-
-// out = b ? val * a : -inf
-template <typename T1>
-inline void _scale_attn_mask_fusion_kernel(
-    T1* a,
-    bool* b,
-    const int& size,
-    T1* out,
-    T1& val) {
-  auto vec_size = at::vec::Vectorized<T1>::size();
-  auto vec_scale = at::vec::Vectorized<T1>(val);
-  auto neg_inf = -std::numeric_limits<T1>::infinity();
-  auto vec_neg_inf = at::vec::Vectorized<T1>(neg_inf);
-  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
-    auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
-    auto tmp1 = at::vec::Vectorized<bool>::loadu(b + i);
-    auto tmp2 = at::vec::convert<T1>(tmp1);
-    auto tmp3 =
-        at::vec::Vectorized<T1>::blendv(vec_neg_inf, tmp0 * vec_scale, tmp2);
-    _store(out + i, tmp3);
-  }
-  for (long i = vec_size * (size / vec_size); i < size; i++) {
-    auto tmp0 = a[i];
-    auto tmp1 = b[i];
-    out[i] = tmp1 ? tmp0 * val : neg_inf;
+  const auto vec_size1 = at::vec::Vectorized<T1>::size();
+  const auto vec_size2 = at::vec::Vectorized<T2>::size();
+  constexpr int64_t T1_n =
+      (vec_size2 == vec_size1 * 2 && is_reduced_floating_point_v<T2>) ? 2 : 1;
+  constexpr int64_t T2_n = 1;
+  auto vec_scale = at::vec::VectorizedN<T1, T1_n>(val);
+  int64_t i = 0;
+  if (is_b_stride_zero) {
+    auto b_first_val = (T1)b[0];
+    auto b_first_vec = at::vec::VectorizedN<T2, T2_n>(b_first_val);
+    for (; i < size - (size % vec_size2); i += vec_size2) {
+      auto a_n = at::vec::VectorizedN<T1, T1_n>::loadu(a + i);
+      auto b_n = b_first_vec;
+      at::vec::VectorizedN<T1, T1_n> b_n_convert =
+          at::vec::convert<T1, T1_n, T2, T2_n, true>(b_n);
+      auto res = a_n * vec_scale + b_n_convert;
+      res.store(out + i);
+    }
+    for (; i < size; i++) {
+      auto tmp0 = a[i];
+      auto tmp1 = b_first_val;
+      out[i] = tmp0 * val + tmp1;
+    }
+  } else {
+    for (; i < size - (size % vec_size2); i += vec_size2) {
+      auto a_n = at::vec::VectorizedN<T1, T1_n>::loadu(a + i);
+      auto b_n = at::vec::VectorizedN<T2, T2_n>::loadu(b + i);
+      at::vec::VectorizedN<T1, T1_n> b_n_convert =
+          at::vec::convert<T1, T1_n, T2, T2_n, true>(b_n);
+      auto res = a_n * vec_scale + b_n_convert;
+      res.store(out + i);
+    }
+    for (; i < size; i++) {
+      auto tmp0 = a[i];
+      auto tmp1 = (T1)b[i];
+      out[i] = tmp0 * val + tmp1;
+    }
   }
 }
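For orientation, a minimal scalar sketch of what the merged kernel above computes (illustrative only, not part of the patch; the function and buffer names are made up): per element, out[i] = a[i] * val + mask, where the mask value is b[i] normally and b[0] for every element when is_b_stride_zero is true.

// Scalar reference sketch (illustrative only; the vectorized kernel above
// produces the same values, just vec_size2 lanes at a time).
template <bool is_b_stride_zero, typename T1, typename T2>
void scale_mask_reference(const T1* a, const T2* b, int size, T1* out, T1 val) {
  for (int i = 0; i < size; ++i) {
    // broadcast case reads the single mask value, otherwise read per element
    T1 m = is_b_stride_zero ? static_cast<T1>(b[0]) : static_cast<T1>(b[i]);
    out[i] = a[i] * val + m;
  }
}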
@@ -425,6 +401,82 @@ inline void _mul_reduce_max_fusion_kernel(
       vec_tmp_max));
 }
 
+// This function is used to produce an attn_mask in a standard format
+inline std::optional<at::Tensor> convert_boolean_attn_mask(
+    const std::optional<at::Tensor>& attn_mask,
+    caffe2::TypeMeta dtype) {
+  // Pass through
+  if (!attn_mask.has_value()) {
+    return c10::nullopt;
+  }
+  // Convert boolean mask to additive mask
+  if (attn_mask->dtype() == at::kBool) {
+    auto new_attn_mask = at::zeros_like(attn_mask.value(), dtype);
+    new_attn_mask.masked_fill_(
+        attn_mask->logical_not(), -std::numeric_limits<double>::infinity());
+    return new_attn_mask;
+  }
+  // Otherwise, attn_mask represents an additive attention tensor
+  return attn_mask;
+}
+
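As a worked example of the conversion above (a standalone sketch with made-up values, not from this patch), a boolean mask becomes an additive mask holding 0 where attention is allowed and -inf where it is masked out:

#include <ATen/ATen.h>
#include <limits>

// Illustrative only: mask out the odd key positions of a length-4 row.
at::Tensor boolean_mask_example() {
  at::Tensor bool_mask = at::remainder(at::arange(4), 2).eq(0);  // true, false, true, false
  at::Tensor additive = at::zeros_like(bool_mask, at::kFloat);
  additive.masked_fill_(
      bool_mask.logical_not(), -std::numeric_limits<double>::infinity());
  return additive;  // {0, -inf, 0, -inf}: adding this to qk suppresses masked keys
}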
+// Support mask shapes:
+// 2d: ({Q_seq_len, 1} x {KV_seq_len, 1})
+// 4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})
+inline bool check_attn_mask_shape(
+    at::Tensor& attn_mask,
+    int64_t batchSize,
+    int64_t num_head,
+    int64_t qSize,
+    int64_t kvSize) {
+  if (attn_mask.size(-2) != qSize && attn_mask.size(-2) != 1) {
+    return false;
+  }
+  if (attn_mask.size(-1) != kvSize && attn_mask.size(-1) != 1) {
+    return false;
+  }
+  if (attn_mask.dim() == 2) {
+    return true;
+  } else if (attn_mask.dim() == 4) {
+    if ((attn_mask.size(0) == 1 || attn_mask.size(0) == batchSize) &&
+        (attn_mask.size(1) == 1 || attn_mask.size(1) == num_head)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Reshape attention mask to 4d
+inline void reshape_attn_mask_to_4d(
+    at::Tensor& attn_mask,
+    int64_t batchSize,
+    int64_t num_head,
+    int64_t qSize,
+    int64_t kvSize) {
+  TORCH_CHECK(
+      check_attn_mask_shape(attn_mask, batchSize, num_head, qSize, kvSize),
+      "IPEX flash_attention: Please use the following attn mask shapes: ",
+      "2d - ({Q_seq_len, 1} x {KV_seq_len, 1}); ",
+      "4d - ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})");
+  int64_t attn_mask_size_0 = 1;
+  int64_t attn_mask_size_1 = 1;
+  if (attn_mask.dim() == 4) {
+    if (attn_mask.size(0) == batchSize) {
+      attn_mask_size_0 = batchSize;
+    }
+    if (attn_mask.size(1) == num_head) {
+      attn_mask_size_1 = num_head;
+    }
+  }
+  attn_mask = attn_mask
+                  .view(
+                      {attn_mask_size_0,
+                       attn_mask_size_1,
+                       attn_mask.size(-2),
+                       attn_mask.size(-1)})
+                  .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize});
+}
+
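To make the broadcasting behind reshape_attn_mask_to_4d concrete, here is a small standalone sketch with made-up sizes (qSize = 8, kvSize = 16); it is not part of the patch:

#include <ATen/ATen.h>

// Illustrative only: a mask with a size-1 query dim is viewed to 4-D and then
// expanded; expand copies no data, it just sets that dim's stride to 0.
at::Tensor broadcast_mask_example() {
  at::Tensor mask = at::randn({1, 16});                // {Q_seq_len = 1, KV_seq_len = 16}
  at::Tensor mask4d =
      mask.view({1, 1, 1, 16}).expand({1, 1, 8, 16});  // {1, 1, qSize, kvSize}
  // mask4d.stride(2) == 0, so every query row reads the same 16 mask values.
  return mask4d;
}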
 /*
 *Calculate the flash attention SDPA.
 *@template scalar_t: q/k/v data type
@@ -480,6 +532,12 @@ cpu_flash_attention(
   int64_t num_head = query.size(2);
   int64_t headSize = query.size(3);
 
+  // reshape mask
+  if (attention_mask.has_value()) {
+    reshape_attn_mask_to_4d(
+        attention_mask.value(), batchSize, num_head, qSize, kvSize);
+  }
+
   // Strides
   int64_t qStrideB = query.stride(0);
   int64_t qStrideM = query.stride(1);
@@ -505,7 +563,13 @@ cpu_flash_attention(
       ? attention_mask.value().stride(1)
       : 0;
   int64_t mStrideM =
-      attention_mask.has_value() ? attention_mask.value().stride(2) : 0;
+      (attention_mask.has_value() && attention_mask.value().size(2) > 1)
+      ? attention_mask.value().stride(2)
+      : 0;
+  int64_t mStrideN =
+      (attention_mask.has_value() && attention_mask.value().size(3) > 1)
+      ? attention_mask.value().stride(3)
+      : 0;
 
   int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
   int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
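Why the new size checks matter: after reshape_attn_mask_to_4d, a mask that broadcasts along the key dimension reports stride 0 on dim 3, so mStrideN ends up 0 and the call sites below can take the is_b_stride_zero path. A standalone sketch with made-up shapes (not from this patch):

#include <ATen/ATen.h>

// Illustrative only: a per-row scalar mask expanded over kvSize keys keeps
// stride 0 on the key dim, which is exactly what mStrideN == 0 detects.
bool mask_broadcasts_over_keys(int64_t qSize, int64_t kvSize) {
  at::Tensor mask =
      at::randn({1, 1, qSize, 1}).expand({1, 1, qSize, kvSize});
  int64_t mStrideN = mask.size(3) > 1 ? mask.stride(3) : 0;
  return mStrideN == 0;  // true: one mask value is reused across the whole row
}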
@@ -596,15 +660,24 @@ cpu_flash_attention(
         // And apply scaling factor
         if (attention_mask.has_value()) {
           for (int64_t row = 0; row < qBlockSize; ++row) {
-            // qk <- attn_mask ? qk : -inf, if attn_mask is bool
-            // qk <- qk + attn_mask, else
-            _scale_attn_mask_fusion_kernel(
-                qk_data + row * kvBlockSize,
-                mask_data + i * mStrideB + j * mStrideH +
-                    (m + row) * mStrideM + n,
-                kvBlockSize,
-                qk_data + row * kvBlockSize,
-                scaling_factor);
+            // qk <- qk * scaling_factor + attn_mask
+            if (mStrideN == 0) {
+              _scale_attn_mask_fusion_kernel</*is_b_stride_zero*/ true>(
+                  qk_data + row * kvBlockSize,
+                  mask_data + i * mStrideB + j * mStrideH +
+                      (m + row) * mStrideM,
+                  kvBlockSize,
+                  qk_data + row * kvBlockSize,
+                  scaling_factor);
+            } else {
+              _scale_attn_mask_fusion_kernel</*is_b_stride_zero*/ false>(
+                  qk_data + row * kvBlockSize,
+                  mask_data + i * mStrideB + j * mStrideH +
+                      (m + row) * mStrideM + n,
+                  kvBlockSize,
+                  qk_data + row * kvBlockSize,
+                  scaling_factor);
+            }
           }
         }
         // Update coefficients with Softmax
@@ -737,6 +810,12 @@ cpu_flash_attention(
   int64_t num_head = query.size(2);
   int64_t headSize = query.size(3);
 
+  // reshape mask
+  if (attention_mask.has_value()) {
+    reshape_attn_mask_to_4d(
+        attention_mask.value(), batchSize, num_head, qSize, kvSize);
+  }
+
   // Strides
   int64_t qStrideB = query.stride(0);
   int64_t qStrideM = query.stride(1);
@@ -762,7 +841,13 @@ cpu_flash_attention(
       ? attention_mask.value().stride(1)
       : 0;
   int64_t mStrideM =
-      attention_mask.has_value() ? attention_mask.value().stride(2) : 0;
+      (attention_mask.has_value() && attention_mask.value().size(2) > 1)
+      ? attention_mask.value().stride(2)
+      : 0;
+  int64_t mStrideN =
+      (attention_mask.has_value() && attention_mask.value().size(3) > 1)
+      ? attention_mask.value().stride(3)
+      : 0;
 
   int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
   int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
@@ -1241,15 +1326,24 @@ cpu_flash_attention(
         // And apply scaling factor
         if (attention_mask.has_value()) {
           for (int64_t row = 0; row < qBlockSize; ++row) {
-            // qk <- attn_mask ? qk : -inf, if attn_mask is bool
-            // qk <- qk + attn_mask, else
-            _scale_attn_mask_fusion_kernel(
-                qk_data + row * kvBlockSize,
-                mask_data + i * mStrideB + j * mStrideH +
-                    (m + row) * mStrideM + n,
-                kvBlockSize,
-                qk_data + row * kvBlockSize,
-                scaling_factor);
+            // qk <- qk * scaling_factor + attn_mask
+            if (mStrideN == 0) {
+              _scale_attn_mask_fusion_kernel</*is_b_stride_zero*/ true>(
+                  qk_data + row * kvBlockSize,
+                  mask_data + i * mStrideB + j * mStrideH +
+                      (m + row) * mStrideM,
+                  kvBlockSize,
+                  qk_data + row * kvBlockSize,
+                  scaling_factor);
+            } else {
+              _scale_attn_mask_fusion_kernel</*is_b_stride_zero*/ false>(
+                  qk_data + row * kvBlockSize,
+                  mask_data + i * mStrideB + j * mStrideH +
+                      (m + row) * mStrideM + n,
+                  kvBlockSize,
+                  qk_data + row * kvBlockSize,
+                  scaling_factor);
+            }
           }
         }
         // Update coefficients with Softmax
@@ -1558,6 +1652,8 @@ std::tuple<at::Tensor, at::Tensor> flash_attention_kernel(
           attention_mask.value().stride(-1) == 1),
       "IPEX flash_attention: Q/K/V/Mask should be contiguous on the last dim");
 
+  std::optional<at::Tensor> attn_mask =
+      convert_boolean_attn_mask(attention_mask, query.dtype());
   at::Tensor output =
       at::empty({batchSize, qSize, num_head, headSize}, query.options());
   const auto accumulate_dtype = at::toOpMathType(dtype);
@@ -1572,7 +1668,7 @@ std::tuple<at::Tensor, at::Tensor> flash_attention_kernel(
       value,
       dropout_p,
       is_causal,
-      attention_mask,
+      attn_mask,
       scale);
 
   output = output.transpose(1, 2);