Skip to content

Commit bee5ab6

Browse files
authored
iakv: fix tiling issue (#3246)
1 parent 1112def commit bee5ab6

File tree

2 files changed

+594
-629
lines changed

2 files changed

+594
-629
lines changed

csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,9 @@ inline void reduce_head(
49 49
}
50 50
attn_w_pos[0] += _mm512_reduce_add_ps(qk_sum_vec);
51 51
for (; hsi < head_size; hsi++) {
52-
k_cache_start[hsi] = k_ptr_start[hsi]; // cat the key into the key_cache.
52+
if (store_key) {
53+
k_cache_start[hsi] = k_ptr_start[hsi]; // cat the key into the key_cache.
54+
}
53 55
attn_w_pos[0] += q_ptr_start[hsi] * k_ptr_start[hsi];
54 56
}
55 57
return;
@@ -80,7 +82,9 @@ inline void reduce_head(
80 82
}
81 83
attn_w_pos[0] += _mm512_reduce_add_ps(qk_sum_vec);
82 84
for (; hsi < head_size; hsi++) {
83-
k_cache_start[hsi] = k_ptr_start[hsi]; // cat the key into the key_cache.
85+
if (store_key) {
86+
k_cache_start[hsi] = k_ptr_start[hsi]; // cat the key into the key_cache.
87+
}
84 88
attn_w_pos[0] += q_ptr_start[hsi] * k_ptr_start[hsi];
85 89
}
86 90
return;
@@ -111,7 +115,9 @@ inline void reduce_head(
111 115
}
112 116
attn_w_pos[0] += _mm512_reduce_add_ps(qk_sum_vec);
113 117
for (; hsi < head_size; hsi++) {
114-
k_cache_start[hsi] = k_ptr_start[hsi]; // cat the key into the key_cache.
118+
if (store_key) {
119+
k_cache_start[hsi] = k_ptr_start[hsi]; // cat the key into the key_cache.
120+
}
115 121
attn_w_pos[0] += q_ptr_start[hsi] * k_ptr_start[hsi];
116 122
}
117 123
return;
@@ -139,7 +145,9 @@ inline void reduce_head_half(
139 145
}
140 146
attn_w_pos[0] += _mm512_reduce_add_ph(qk_sum_vec);
141 147
for (; hsi < head_size; hsi++) {
142-
k_cache_start[hsi] = k_ptr_start[hsi]; // cat the key into the key_cache.
148+
if (store_key) {
149+
k_cache_start[hsi] = k_ptr_start[hsi]; // cat the key into the key_cache.
150+
}
143 151
attn_w_pos[0] += q_ptr_start[hsi] * k_ptr_start[hsi];
144 152
}
145 153
}

0 commit comments

Comments
 (0)