Commit 638a7d2

fix accuracy issue on reshape_and_cache by selecting correct index (#3307)
1 parent 91639fa commit 638a7d2

2 files changed: +33 -6 lines changed

csrc/cpu/aten/kernels/PagedAttentionKrnl.cpp

Lines changed: 9 additions & 5 deletions
@@ -555,20 +555,24 @@ void reshape_and_cache_kernel(
   auto cache_strideN = key_cache.stride(0);
   auto cache_strideP = key_cache.stride(2);
   auto cache_strideH = key_cache.stride(1);
-  auto state_strideN = key.stride(0);
-  auto state_strideH = key.stride(1);
+  auto key_state_strideN = key.stride(0);
+  auto key_state_strideH = key.stride(1);
+  auto value_state_strideN = value.stride(0);
+  auto value_state_strideH = value.stride(1);
 #pragma omp parallel for collapse(2)
   for (auto ti = 0; ti < num_tokens; ti++) {
     for (auto hi = 0; hi < head_num; hi++) {
       auto physical_block_id = slot_mapping_ptr[ti] / block_size;
       auto block_offset = slot_mapping_ptr[ti] % block_size;
       auto cache_offset = physical_block_id * cache_strideN +
           block_offset * cache_strideP + hi * cache_strideH;
-      auto state_offset = ti * state_strideN + hi * state_strideH;
+      auto key_state_offset = ti * key_state_strideN + hi * key_state_strideH;
+      auto value_state_offset =
+          ti * value_state_strideN + hi * value_state_strideH;
       auto key_cache_start = key_cache_ptr + cache_offset;
-      auto key_ptr_start = key_ptr + state_offset;
+      auto key_ptr_start = key_ptr + key_state_offset;
       auto value_cache_start = value_cache_ptr + cache_offset;
-      auto value_ptr_start = value_ptr + state_offset;
+      auto value_ptr_start = value_ptr + value_state_offset;
       torch_ipex::cpu::kernel::move_ker<DST_T, SRC_T>(
           key_cache_start, key_ptr_start, head_size);
       torch_ipex::cpu::kernel::move_ker<DST_T, SRC_T>(
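
The deleted lines computed a single state_offset from key's strides and used it for both the key and the value source pointers, which is only correct when the two tensors share a memory layout. As a minimal sketch (illustrative sizes only, not part of this commit), the snippet below shows how key and value unbound from one packed qkv buffer can end up with different strides once one of them goes through a layout change, the case the separate key_state_offset and value_state_offset now handle:

    import torch

    # Hypothetical sizes for illustration only.
    num_token, num_head, head_size = 4, 8, 64
    qkv = torch.randn(num_token, 3, num_head, head_size)
    _, key, value = qkv.unbind(dim=1)   # strided views into qkv
    print(key.stride())                 # (1536, 64, 1)

    # Rebuild value in a [num_head, num_token, head_size]-contiguous layout and
    # view it back: same shape and values as before, but different strides.
    value = value.transpose(0, 1).contiguous().transpose(0, 1)
    print(value.stride())               # (64, 256, 1)

    # Indexing this value tensor with key's strides (the old shared state_offset)
    # reads the wrong elements; deriving each offset from its own tensor's strides
    # copies the intended rows into the cache.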

tests/cpu/test_paged_attention.py

Lines changed: 24 additions & 1 deletion
@@ -252,6 +252,8 @@ def _test_reshape_and_cache_func(
         num_blocks: int,
         dtype: torch.dtype,
         seed: int,
+        key_is_contiguous: bool,
+        value_is_contiguous: bool,
     ) -> None:
         random.seed(seed)
         torch.random.manual_seed(seed)
@@ -264,6 +266,13 @@ def _test_reshape_and_cache_func(
 
         qkv = torch.randn(num_token, 3, num_head, head_size, dtype=dtype, device="cpu")
         _, key, value = qkv.unbind(dim=1)
+        if key.shape[0] != 1:
+            if not key_is_contiguous:
+                key = key.transpose(0, 1).contiguous()
+                key = key.transpose(0, 1)
+            if not value_is_contiguous:
+                value = value.transpose(0, 1).contiguous()
+                value = value.transpose(0, 1)
         # Create the KV caches.
         key_caches, value_caches = self.create_kv_caches(
             num_blocks, block_size, 1, num_head, head_size, dtype, seed
@@ -300,6 +309,8 @@ def test_reshape_and_cache(self):
         head_sizes = [64, 80, 128, 96, 112, 128, 256]
         block_sizes = [16, 32]
         dtypes = [torch.bfloat16, torch.float]
+        key_modes = [True, False]
+        value_modes = [True, False]
         if core.onednn_has_fp16_support():
             dtypes.append(torch.float16)
         seeds = [0]
@@ -310,16 +321,28 @@ def test_reshape_and_cache(self):
             block_size,
             dtype,
             seed,
+            key_is_contiguous,
+            value_is_contiguous,
         ) in product(
             num_tokens,
             num_kv_heads,
             head_sizes,
             block_sizes,
             dtypes,
             seeds,
+            key_modes,
+            value_modes,
        ):
             self._test_reshape_and_cache_func(
-                num_token, num_kv_head, head_size, block_size, num_blocks, dtype, seed
+                num_token,
+                num_kv_head,
+                head_size,
+                block_size,
+                num_blocks,
+                dtype,
+                seed,
+                key_is_contiguous,
+                value_is_contiguous,
             )
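
The new key_modes and value_modes flags drive the transpose/contiguous/transpose round trip added above: it keeps a tensor's shape and values but changes its memory layout, so the parameterized test now covers every combination of matching and mismatching key/value layouts, including the mismatch that previously produced wrong cache contents. A minimal sketch of what that round trip does (standalone example, not taken from the test):

    import torch

    key = torch.randn(4, 8, 64)
    reordered = key.transpose(0, 1).contiguous().transpose(0, 1)
    assert reordered.shape == key.shape        # shape preserved
    assert reordered.stride() != key.stride()  # memory layout changed
    assert torch.equal(reordered, key)         # values unchanged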