@@ -49,7 +49,9 @@ inline void reduce_head(
49
49
}
50
50
attn_w_pos[0 ] += _mm512_reduce_add_ps (qk_sum_vec);
51
51
for (; hsi < head_size; hsi++) {
52
- k_cache_start[hsi] = k_ptr_start[hsi]; // cat the key into the key_cache.
52
+ if (store_key) {
53
+ k_cache_start[hsi] = k_ptr_start[hsi]; // cat the key into the key_cache.
54
+ }
53
55
attn_w_pos[0 ] += q_ptr_start[hsi] * k_ptr_start[hsi];
54
56
}
55
57
return ;
@@ -80,7 +82,9 @@ inline void reduce_head(
80
82
}
81
83
attn_w_pos[0 ] += _mm512_reduce_add_ps (qk_sum_vec);
82
84
for (; hsi < head_size; hsi++) {
83
- k_cache_start[hsi] = k_ptr_start[hsi]; // cat the key into the key_cache.
85
+ if (store_key) {
86
+ k_cache_start[hsi] = k_ptr_start[hsi]; // cat the key into the key_cache.
87
+ }
84
88
attn_w_pos[0 ] += q_ptr_start[hsi] * k_ptr_start[hsi];
85
89
}
86
90
return ;
@@ -111,7 +115,9 @@ inline void reduce_head(
111
115
}
112
116
attn_w_pos[0 ] += _mm512_reduce_add_ps (qk_sum_vec);
113
117
for (; hsi < head_size; hsi++) {
114
- k_cache_start[hsi] = k_ptr_start[hsi]; // cat the key into the key_cache.
118
+ if (store_key) {
119
+ k_cache_start[hsi] = k_ptr_start[hsi]; // cat the key into the key_cache.
120
+ }
115
121
attn_w_pos[0 ] += q_ptr_start[hsi] * k_ptr_start[hsi];
116
122
}
117
123
return ;
@@ -139,7 +145,9 @@ inline void reduce_head_half(
139
145
}
140
146
attn_w_pos[0 ] += _mm512_reduce_add_ph (qk_sum_vec);
141
147
for (; hsi < head_size; hsi++) {
142
- k_cache_start[hsi] = k_ptr_start[hsi]; // cat the key into the key_cache.
148
+ if (store_key) {
149
+ k_cache_start[hsi] = k_ptr_start[hsi]; // cat the key into the key_cache.
150
+ }
143
151
attn_w_pos[0 ] += q_ptr_start[hsi] * k_ptr_start[hsi];
144
152
}
145
153
}
0 commit comments