Commit 70321a1

kv-cache : simplify the interface (wip) [no ci]
1 parent a4090d1 commit 70321a1

9 files changed: +52 additions, -131 deletions


examples/simple-chat/simple-chat.cpp

Lines changed: 2 additions & 2 deletions
@@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
    auto generate = [&](const std::string & prompt) {
        std::string response;

-        const bool is_first = llama_kv_self_used_cells(ctx) == 0;
+        const bool is_first = llama_kv_self_seq_pos_max(ctx, 0) == 0;

        // tokenize the prompt
        const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
@@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
        while (true) {
            // check if we have enough space in the context to evaluate this batch
            int n_ctx = llama_n_ctx(ctx);
-            int n_ctx_used = llama_kv_self_used_cells(ctx);
+            int n_ctx_used = llama_kv_self_seq_pos_max(ctx, 0);
            if (n_ctx_used + batch.n_tokens > n_ctx) {
                printf("\033[0m\n");
                fprintf(stderr, "context size exceeded\n");
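
For reference, a minimal caller-side sketch of the pattern this file switches to (it assumes the llama.cpp API at this commit; the helper name context_has_room is illustrative, not part of the example):

    #include "llama.h"

    // For a single-sequence chat, the highest position stored for sequence 0 serves both
    // as the "is the cache empty?" check and as the number of context cells in use.
    static bool context_has_room(llama_context * ctx, const llama_batch & batch) {
        const int32_t n_ctx      = llama_n_ctx(ctx);
        const int32_t n_ctx_used = llama_kv_self_seq_pos_max(ctx, 0);

        return n_ctx_used + batch.n_tokens <= n_ctx;
    }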

include/llama.h

Lines changed: 4 additions & 2 deletions
@@ -610,10 +610,12 @@ extern "C" {

    // Returns the number of tokens in the KV cache (slow, use only for debug)
    // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
-    LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
+    DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx),
+        "Use llama_kv_self_seq_pos_max() instead");

    // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
-    LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
+    DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
+        "Use llama_kv_self_seq_pos_max() instead");

    // Clear the KV cache - both cell info is erased and KV data is zeroed
    LLAMA_API void llama_kv_self_clear(
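
The DEPRECATED(decl, hint) wrapper used here already exists in llama.h; as a hedged sketch (not the verbatim definition from the header), it is typically built on the compiler-specific deprecation attributes:

    #if defined(__GNUC__)
    #    define DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
    #elif defined(_MSC_VER)
    #    define DEPRECATED(func, hint) __declspec(deprecated(hint)) func
    #else
    #    define DEPRECATED(func, hint) func
    #endif

With that expansion, existing callers of llama_kv_self_n_tokens() and llama_kv_self_used_cells() keep compiling but get a warning steering them toward llama_kv_self_seq_pos_max().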

src/llama-batch.cpp

Lines changed: 1 addition & 1 deletion
@@ -283,7 +283,7 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
    if (!batch.pos) {
        pos.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
-            pos[i] = i + p0;
+            pos[i] = p0 + i + 1;
        }
        batch.pos = pos.data();
    }
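
Since p0 is now the highest position already stored for the sequence (see the llama_kv_self_seq_pos_max / seq_pos_max changes elsewhere in this commit) rather than a past-the-end index, the implicit positions start at p0 + 1. A standalone sketch of the arithmetic with illustrative values:

    #include <cstdio>
    #include <vector>

    int main() {
        const int p0       = 11; // e.g. positions 0..11 are already in the cache for this sequence
        const int n_tokens = 4;  // new tokens arriving without explicit positions

        std::vector<int> pos(n_tokens);
        for (int i = 0; i < n_tokens; i++) {
            pos[i] = p0 + i + 1; // same fill as llama_batch_allocr above
        }

        for (int p : pos) {
            printf("%d ", p); // prints: 12 13 14 15
        }
        printf("\n");
        return 0;
    }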

src/llama-context.cpp

Lines changed: 14 additions & 4 deletions
@@ -857,11 +857,17 @@ int llama_context::decode(llama_batch & inp_batch) {
        return -1;
    }

+    if (!inp_batch.pos) {
+        if (inp_batch.seq_id) {
+            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
+            return -1;
+        }
+    }
+
    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());

    // temporary allocate memory for the input batch if needed
-    // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0));

    const llama_batch & batch = batch_allocr.batch;

@@ -2292,22 +2298,26 @@ int32_t llama_apply_adapter_cvec(
// kv cache
//

+// deprecated
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
    const auto * kv = ctx->get_kv_self();
    if (!kv) {
        return 0;
    }

-    return kv->get_n_tokens();
+#pragma message("implement me")
+    return 0;
}

+// deprecated
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
    const auto * kv = ctx->get_kv_self();
    if (!kv) {
        return 0;
    }

-    return kv->get_used_cells();
+#pragma message("implement me")
+    return 0;
}

void llama_kv_self_clear(llama_context * ctx) {
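
A caller-side sketch of the contract introduced by the new guard in decode() (it assumes the llama.cpp API at this commit): if positions are left implicit, the sequence ids must be left implicit too, which is exactly what llama_batch_get_one() already does.

    #include "llama.h"
    #include <vector>

    // Decode a chunk of tokens relying on implicit positions: llama_batch_get_one()
    // leaves both batch.pos and batch.seq_id as NULL, so decode() derives positions
    // from seq_pos_max(0); a batch with pos == NULL but seq_id set is now rejected.
    static int decode_with_implicit_positions(llama_context * ctx, std::vector<llama_token> & tokens) {
        llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());
        return llama_decode(ctx, batch);
    }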

src/llama-kv-cache.cpp

Lines changed: 15 additions & 72 deletions
@@ -30,10 +30,11 @@ llama_kv_cache_unified::llama_kv_cache_unified(
        bool v_trans,
        bool offload,
        uint32_t kv_size,
-        uint32_t padding,
+        uint32_t n_seq_max,
+        uint32_t n_pad,
        uint32_t n_swa,
-        llama_swa_type swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding), n_swa(n_swa), swa_type(swa_type) {
-    GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");
+        llama_swa_type swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
+    GGML_ASSERT(kv_size % n_pad == 0 && "kv_size must be a multiple of padding");

    this->type_k = type_k;
    this->type_v = type_v;
@@ -442,7 +443,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
void llama_kv_cache_unified::defrag_sched(float thold) {
    // - do not defrag small contexts (i.e. < 2048 tokens)
    // - count the padding towards the number of used tokens
-    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f;
+    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f;

    // queue defragmentation for next llama_kv_cache_update
    if (fragmentation > thold) {
@@ -558,7 +559,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
    // a heuristic, to avoid attending the full cache if it is not yet utilized
    // after enough generations, the benefit from this heuristic disappears
    // if we start defragmenting the cache, the benefit from this will be more important
-    n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));
+    n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad)));

#ifdef FIND_SLOT_DEBUG
    LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
@@ -567,20 +568,6 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
    return true;
}

-int32_t llama_kv_cache_unified::get_n_tokens() const {
-    int32_t result = 0;
-
-    for (uint32_t i = 0; i < size; i++) {
-        result += cells[i].seq_id.size();
-    }
-
-    return result;
-}
-
-int32_t llama_kv_cache_unified::get_used_cells() const {
-    return used;
-}
-
bool llama_kv_cache_unified::get_can_shift() const {
    return true;
}
@@ -802,16 +789,6 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
    }
}

-llama_pos llama_kv_cache_unified::get_pos_max() const {
-    llama_pos pos_max = -1;
-
-    for (const auto & cell : cells) {
-        pos_max = std::max(pos_max, cell.pos);
-    }
-
-    return pos_max;
-}
-
size_t llama_kv_cache_unified::total_size() const {
    size_t size = 0;

@@ -1655,17 +1632,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
        ggml_type type_v,
        bool v_trans,
        bool offload,
-        uint32_t kv_size,
        bool swa_full,
+        uint32_t kv_size,
        uint32_t n_seq_max,
        uint32_t n_batch,
-        uint32_t padding) : hparams(model.hparams) {
+        uint32_t n_pad) : hparams(model.hparams) {
    llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
    llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };

    const uint32_t size_base = kv_size;

-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));

    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning
    if (swa_full) {
@@ -1680,14 +1657,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(

    kv_base = std::make_unique<llama_kv_cache_unified>(
            model, std::move(filter_base), type_k, type_v,
-            v_trans, offload, size_base, padding,
+            v_trans, offload, size_base, n_seq_max, n_pad,
            0, LLAMA_SWA_TYPE_NONE);

    LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);

    kv_swa = std::make_unique<llama_kv_cache_unified>(
            model, std::move(filter_swa), type_k, type_v,
-            v_trans, offload, size_swa, padding,
+            v_trans, offload, size_swa, n_seq_max, n_pad,
            hparams.n_swa, hparams.swa_type);
}
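
The SWA cache size above is derived from the window length, the number of sequences, and the batch size, rounded up to the cell padding. A standalone arithmetic sketch with illustrative values (the ROUND_UP macro stands in for GGML_PAD's round-up-to-multiple behaviour):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    #define ROUND_UP(x, n) (((x) + (n) - 1) / (n) * (n)) // round x up to a multiple of n

    int main() {
        const uint32_t size_base = 8192; // base (non-SWA) cache size
        const uint32_t n_swa     = 1024; // sliding window length from hparams
        const uint32_t n_seq_max = 2;    // sequences that may be active at once
        const uint32_t n_batch   = 512;  // a full batch must also fit in the window cache
        const uint32_t n_pad     = 32;   // required cell padding

        const uint32_t size_swa = std::min(size_base, (uint32_t) ROUND_UP(n_swa*n_seq_max + n_batch, n_pad));

        printf("size_swa = %u\n", size_swa); // 1024*2 + 512 = 2560 (already a multiple of 32)
        return 0;
    }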

@@ -1810,18 +1787,6 @@ bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
    return res;
}

-int32_t llama_kv_cache_unified_iswa::get_n_tokens() const {
-    return kv_base->get_n_tokens();
-}
-
-int32_t llama_kv_cache_unified_iswa::get_used_cells() const {
-    return kv_base->get_used_cells();
-}
-
-llama_pos llama_kv_cache_unified_iswa::get_pos_max() const {
-    return kv_base->get_pos_max();
-}
-
bool llama_kv_cache_unified_iswa::get_can_shift() const {
    return kv_base->get_size() == kv_swa->get_size();
}
@@ -1853,7 +1818,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
        ggml_type type_k,
        ggml_type type_v,
        bool offload,
-        uint32_t kv_size) : hparams(model.hparams) {
+        uint32_t kv_size,
+        uint32_t n_seq_max) : hparams(model.hparams) {
    const int32_t n_layer = hparams.n_layer;

    LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
@@ -2203,8 +2169,8 @@ void llama_kv_cache_recurrent::commit() {
    pending.ranges.clear();
}

-bool llama_kv_cache_recurrent::update(llama_context & lctx) {
-    GGML_UNUSED(lctx);
+bool llama_kv_cache_recurrent::update(llama_context & ctx) {
+    GGML_UNUSED(ctx);
    return false;
}

@@ -2408,29 +2374,6 @@ bool llama_kv_cache_recurrent::find_slot(
    return n >= n_seqs;
}

-int32_t llama_kv_cache_recurrent::get_n_tokens() const {
-    int32_t result = 0;
-
-    for (uint32_t i = 0; i < size; i++) {
-        result += cells[i].seq_id.size();
-    }
-
-    return result;
-}
-
-int32_t llama_kv_cache_recurrent::get_used_cells() const {
-    return used;
-}
-
-llama_pos llama_kv_cache_recurrent::get_pos_max() const {
-    llama_pos pos_max = -1;
-    for (const auto & cell : cells) {
-        pos_max = std::max(pos_max, cell.pos);
-    }
-
-    return pos_max;
-}
-
bool llama_kv_cache_recurrent::get_can_shift() const {
    return false;
}
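
The n_pad rename does not change the find_slot() heuristic itself; a standalone sketch of the rounding it performs, with illustrative values (the ROUND_UP macro mirrors GGML_PAD's round-up-to-multiple behaviour for a power-of-two n_pad):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    #define ROUND_UP(x, n) (((x) + (n) - 1) / (n) * (n))

    int main() {
        const uint32_t size     = 4096; // total number of KV cells
        const uint32_t n_pad    = 32;   // required padding
        const uint32_t cell_max = 70;   // highest used cell index + 1 (example value)

        // same expression as llama_kv_cache_unified::find_slot()
        const uint32_t n = std::min(size, std::max(n_pad, (uint32_t) ROUND_UP(cell_max, n_pad)));

        printf("n = %u\n", n); // 70 rounds up to 96, so attention only covers the first 96 cells
        return 0;
    }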

src/llama-kv-cache.h

Lines changed: 10 additions & 29 deletions
@@ -55,10 +55,7 @@ struct llama_kv_cache : public llama_memory_i {
    // =============================================================================================================

    // getters
-    virtual int32_t get_n_tokens() const = 0;
-    virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
-    virtual llama_pos get_pos_max() const = 0;
-    virtual bool get_can_shift() const = 0;
+    virtual bool get_can_shift() const = 0;

    bool get_can_edit() const override { return get_can_shift(); }
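
After this hunk, the only getter the base interface still requires is the capability query; a hedged sketch of a derived cache's getter section (the class name llama_kv_cache_toy is illustrative, not part of the tree):

    #include "llama-kv-cache.h"

    // Hypothetical cache used only to illustrate the slimmed-down getter section:
    // get_can_shift() is the one getter left to override here; the seq_*, find_slot
    // and state I/O pure virtuals elsewhere in the interface are still required
    // before the class could actually be instantiated.
    class llama_kv_cache_toy : public llama_kv_cache {
    public:
        bool get_can_shift() const override { return false; }
    };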

@@ -108,7 +105,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
        bool v_trans,
        bool offload,
        uint32_t kv_size,
-        uint32_t padding,
+        uint32_t n_seq_max,
+        uint32_t n_pad,
        uint32_t n_swa,
        llama_swa_type swa_type);

@@ -150,12 +148,6 @@ class llama_kv_cache_unified : public llama_kv_cache {
    // to the first cell of the slot.
    bool find_slot(const llama_ubatch & batch) override;

-    int32_t get_n_tokens() const override;
-    int32_t get_used_cells() const override;
-
-    // TODO: better data structures to reduce the cost of this operation
-    llama_pos get_pos_max() const override;
-
    bool get_can_shift() const override;

    // state write/load
@@ -229,7 +221,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
    uint32_t n = 0;

    // required padding
-    uint32_t padding = 1;
+    uint32_t n_pad = 1;

    ggml_type type_k = GGML_TYPE_F16;
    ggml_type type_v = GGML_TYPE_F16;
@@ -317,11 +309,11 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
        ggml_type type_v,
        bool v_trans,
        bool offload,
-        uint32_t kv_size,
        bool swa_full,
+        uint32_t kv_size,
        uint32_t n_seq_max,
        uint32_t n_batch,
-        uint32_t padding);
+        uint32_t n_pad);

    ~llama_kv_cache_unified_iswa() = default;

@@ -358,12 +350,6 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {

    bool find_slot(const llama_ubatch & batch) override;

-    int32_t get_n_tokens() const override;
-    int32_t get_used_cells() const override;
-
-    // TODO: better data structures to reduce the cost of this operation
-    llama_pos get_pos_max() const override;
-
    bool get_can_shift() const override;

    // state write/load
@@ -432,7 +418,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
        ggml_type type_k,
        ggml_type type_v,
        bool offload,
-        uint32_t kv_size);
+        uint32_t kv_size,
+        uint32_t n_seq_max);

    ~llama_kv_cache_recurrent() = default;

@@ -444,7 +431,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {

    bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
-    void seq_keep(llama_seq_id seq_id) override;
+    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

@@ -458,7 +445,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
    void restore() override;
    void commit() override;

-    bool update(llama_context & lctx) override;
+    bool update(llama_context & ctx) override;

    void defrag_sched(float thold) override;

@@ -469,12 +456,6 @@ class llama_kv_cache_recurrent : public llama_kv_cache {

    bool find_slot(const llama_ubatch & batch) override;

-    int32_t get_n_tokens() const override;
-    int32_t get_used_cells() const override;
-
-    // TODO: better data structures to reduce the cost of this operation
-    llama_pos get_pos_max() const override;
-
    bool get_can_shift() const override;

    // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
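
With get_pos_max() removed from all three implementations, code that needs a global maximum position can ask per sequence instead; a minimal sketch, assuming seq_pos_max() stays available on the llama_kv_cache interface as used by llama_context::decode() in this commit:

    #include <algorithm>

    #include "llama-kv-cache.h"

    // Highest position stored across the first n_seq_max sequences; mirrors what the
    // removed get_pos_max() computed globally over all cells.
    static llama_pos kv_pos_max_across_seqs(llama_kv_cache & kv, uint32_t n_seq_max) {
        llama_pos res = -1;
        for (uint32_t s = 0; s < n_seq_max; ++s) {
            res = std::max(res, kv.seq_pos_max((llama_seq_id) s));
        }
        return res;
    }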
