@@ -2951,7 +2951,8 @@ struct server_context {
             llama_kv_self_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
             llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
 
-            if (slot.params.cache_prompt) {
+            // add generated tokens to cache
+            {
                 llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
                 for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
                     new_tokens[i - n_discard] = new_tokens[i];
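In this hunk the `if (slot.params.cache_prompt)` guard becomes an unconditional block: after the context shift evicts `n_discard` tokens following the first `n_keep`, the cached token list is now always compacted so it keeps mirroring the KV cache. A minimal self-contained sketch of that compaction, with `llama_tokens` approximated by a plain vector (the helper name is mine, not from the patch):

```cpp
#include <cstdint>
#include <vector>

using llama_tokens = std::vector<int32_t>; // stand-in for the real typedef

// Drop tokens [n_keep, n_keep + n_discard) and shift the tail left,
// matching what llama_kv_self_seq_rm/seq_add did to the KV cache above.
static void compact_after_shift(llama_tokens & tokens, size_t n_keep, size_t n_discard) {
    for (size_t i = n_keep + n_discard; i < tokens.size(); i++) {
        tokens[i - n_discard] = tokens[i];
    }
    tokens.resize(tokens.size() - n_discard);
}
```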
@@ -2996,10 +2997,7 @@ struct server_context {
             common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
 
             slot.n_past += 1;
-
-            if (slot.params.cache_prompt) {
-                slot.cache_tokens.push_back(slot.sampled);
-            }
+            slot.cache_tokens.push_back(slot.sampled);
 
             SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
                 slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
@@ -3171,6 +3169,11 @@ struct server_context {
 
                         SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
                     }
+                } else {
+                    // if we don't cache the prompt, we have to remove the entire KV cache
+                    llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
+                    slot.n_past = 0;
+                    slot.cache_tokens.clear();
                 }
             }
 
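The new `else` branch covers slots that opt out of prompt caching: nothing can be reused, so the slot's entire sequence is evicted and its bookkeeping is zeroed. Note the `llama_kv_self_seq_rm(ctx, seq_id, p0, p1)` convention: the call removes cells in `[p0, p1)`, and a negative `p1` means "to the end of the sequence", so `(0, -1)` clears everything. The same reset, factored into an illustrative helper (the helper name and the simplified slot struct are mine; the real state lives in `server_slot` in server.cpp):

```cpp
#include <cstdint>
#include <vector>

#include "llama.h" // llama_context, llama_seq_id, llama_kv_self_seq_rm

// Minimal stand-in for the server's per-slot state.
struct slot_state {
    llama_seq_id             id;
    int32_t                  n_past;
    std::vector<llama_token> cache_tokens; // simplified: text tokens only
};

// Reset a slot to a cold start when prompt caching is disabled.
static void purge_slot_cache(llama_context * ctx, slot_state & slot) {
    llama_kv_self_seq_rm(ctx, slot.id, 0, -1); // drop the whole sequence
    slot.n_past = 0;                           // next decode starts at position 0
    slot.cache_tokens.clear();                 // token mirror must match the (now empty) KV cache
}
```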
@@ -3204,7 +3207,7 @@ struct server_context {
                 SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
 
                 // remove the non-common part from the cache
-                slot.cache_tokens.resize(slot.n_past);
+                slot.cache_tokens.keep_first(slot.n_past);
 
                 // check if we should process the image
                 if (slot.n_past < slot.n_prompt_tokens
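`resize()` gives way to `keep_first()` here because, as the surrounding hunks show (`get_text_tokens()`, `find_chunk()`, `push_back(chunk.get())`), `slot.cache_tokens` is not a flat token vector but a chunk-aware container that can also hold image chunks, and truncation must not cut through the middle of an image chunk. A hypothetical sketch of that guard; the sentinel value and field layout are my assumptions, not the real container's implementation:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical chunk-aware buffer: text positions hold a token id,
// positions covered by an image chunk hold a sentinel (-1 here).
struct chunked_tokens {
    std::vector<int32_t> tokens;

    // Keep only the first n positions. Unlike resize(), refuse to cut
    // an image chunk in half: the last kept position must be text.
    void keep_first(size_t n) {
        assert(n <= tokens.size());
        if (n > 0 && n < tokens.size()) {
            assert(tokens[n - 1] != -1 && "cannot truncate inside an image chunk");
        }
        tokens.resize(n);
    }
};
```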
@@ -3221,7 +3224,8 @@ struct server_context {
                         continue;
                     }
 
-                    if (slot.params.cache_prompt) {
+                    // add the image chunk to cache
+                    {
                         const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
                         slot.cache_tokens.push_back(chunk.get()); // copy
                     }
@@ -3242,9 +3246,7 @@ struct server_context {
                 const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
                 common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
-                if (slot.params.cache_prompt) {
-                    slot.cache_tokens.push_back(cur_tok);
-                }
+                slot.cache_tokens.push_back(cur_tok);
 
                 slot.n_prompt_tokens_processed++;
                 slot.n_past++;
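Taken together, the removed `cache_prompt` guards establish a single invariant: `slot.cache_tokens` always mirrors the slot's KV-cache sequence, one entry per decoded position, whether the token came from sampling (earlier hunks) or from the prompt (this one). One way to state that invariant next to the counters this hunk advances; the assert is illustrative, not part of the patch:

```cpp
// After queuing a token at position slot.n_past and advancing the counters,
// the mirror's length must equal the new n_past -- illustrative check only.
GGML_ASSERT((int32_t) slot.cache_tokens.size() == slot.n_past);
```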