diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 410d2608798b8..47c382f3d87f2 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1156,6 +1156,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
         if (v_mla) {
             kqv = ggml_mul_mat(ctx0, v_mla, kqv);
+            // all nodes between the KV store and the attention output are run on the CPU
+            if (!cparams.offload_kqv) {
+                ggml_backend_sched_set_tensor_backend(sched, kqv, backend_cpu);
+            }
         }

         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
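
For context: when KV offload is disabled, the KV cache lives in host memory, so the decompressed MLA `kqv` tensor is explicitly pinned to the CPU backend here to keep the scheduler from bouncing this node to an accelerator and back. Below is a minimal sketch of how a caller ends up on this path through the public llama.h API; it assumes the current API names (llama_model_load_from_file, llama_init_from_model) and uses a placeholder model path, so treat it as illustrative rather than part of this patch.

#include "llama.h"

int main(void) {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    // "model.gguf" is a placeholder path for this sketch
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (model == NULL) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    // disable KV offload; this is what the -nkvo / --no-kv-offload CLI flag sets,
    // and it is the condition (!cparams.offload_kqv) checked in the patch above
    cparams.offload_kqv = false;

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (ctx == NULL) {
        llama_model_free(model);
        return 1;
    }

    // ... run decode as usual; with offload_kqv disabled, the MLA
    // "decompression" matmul above now stays on the CPU backend ...

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}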