Commit a059161

llama : auto-batch

ggml-ci

1 parent 12d0188

File tree: 5 files changed, +52 -40 lines changed

examples/parallel/parallel.cpp

Lines changed: 1 addition & 1 deletion

@@ -392,7 +392,7 @@ int main(int argc, char ** argv) {
                 return 1;
             }
 
-            LOG_ERR("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
+            LOG_WRN("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
 
            n_cache_miss += 1;

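For context, this warning fires inside the example's decode loop, which shrinks the batch whenever the KV cache has no free slot. A minimal self-contained sketch of that caller-side pattern, assuming the return-code convention visible in this commit (0 = success, 1 = no KV cache slot found, < 0 = fatal); the helper name decode_tokens is illustrative and not part of the example:

#include <algorithm>
#include <vector>
#include "llama.h"

// Sketch: feed `tokens` to the model in chunks of at most n_batch tokens.
// On a retryable failure (ret == 1: no KV cache slot found), halve n_batch
// and resubmit the same chunk, mirroring examples/parallel and the server.
static bool decode_tokens(llama_context * ctx, std::vector<llama_token> & tokens, int32_t n_batch) {
    for (size_t i = 0; i < tokens.size(); ) {
        const int32_t n_eval = std::min<int32_t>(n_batch, (int32_t) (tokens.size() - i));

        const int ret = llama_decode(ctx, llama_batch_get_one(tokens.data() + i, n_eval));

        if (ret == 0) {
            i += n_eval;      // chunk accepted - advance to the next one
        } else if (ret == 1 && n_batch > 1) {
            n_batch /= 2;     // no free KV slot - retry with a smaller chunk
        } else {
            return false;     // fatal error (ret < 0), or cannot shrink further
        }
    }
    return true;
}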
src/llama-context.cpp

Lines changed: 45 additions & 35 deletions

@@ -424,9 +424,9 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-void llama_context::kv_self_update() {
+bool llama_context::kv_self_update() {
     if (!memory) {
-        return;
+        return false;
     }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
@@ -445,7 +445,11 @@ void llama_context::kv_self_update() {
         if (!gf) {
             LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
         }
+
+        return true;
     }
+
+    return false;
 }
 
 enum llama_pooling_type llama_context::pooling_type() const {
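
The change from void to bool is what enables the auto-retry in decode() below: the caller can now tell whether a pending shift/defrag was actually applied (so re-attempting slot allocation may succeed) or whether nothing changed (so a retry would fail identically).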
@@ -933,24 +937,44 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();
 
-    auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-    if (!kv_state) {
-        return -2;
-    }
+    llama_memory_state_ptr kv_state;
 
-    switch (kv_state->get_status()) {
-        case LLAMA_MEMORY_STATUS_SUCCESS:
-            {
-            } break;
-        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
-            {
-                // not a fatal error, we can re-try with a different batch
-                return 1;
-            }
-        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
-            {
-                return -2;
-            }
+    bool did_defrag = false;
+
+    while (true) {
+        kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!kv_state) {
+            return -2;
+        }
+
+        switch (kv_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                } break;
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+                {
+                    if (!did_defrag) {
+                        did_defrag = true;
+
+                        kv_self->defrag_sched(-1.0f);
+                        if (kv_self_update()) {
+                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+
+                            continue;
+                        }
+                    }
+
+                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+                    return 1;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    return -2;
+                }
+        }
+
+        break;
     }
 
     // reserve output buffer
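
The loop is bounded by design: init_batch() gets at most one retry. On the first FAILED_PREPARE, an immediate defrag is scheduled via defrag_sched(-1.0f) and applied by kv_self_update(), and the loop continues only if that update reports the cache actually changed. A second FAILED_PREPARE (or a no-op update) falls through to the warning and the retryable return code 1, so callers can still shrink the batch themselves.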
@@ -2646,22 +2670,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
         llama_batch batch) {
-    int ret = ctx->decode(batch);
-
-    // defrag and try again
-    // TODO: distinguish return code when we are sure that even after defrag there is no space available
-    if (ret == 1) {
-        llama_kv_self_defrag(ctx);
-        ret = ctx->decode(batch);
-
-        if (ret == 1) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
-            return ret;
-        }
-    }
-
-    if (ret != 0) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }

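With the defrag-and-retry moved inside llama_context::decode(), a return value of 1 now reaches the public API instead of being retried here, which is why it is excluded from the error log. A hedged sketch of how a caller distinguishes the two failure classes (the wrapper name try_decode and the messages are illustrative):

#include <cstdio>
#include "llama.h"

// Sketch: interpreting llama_decode() return codes after this change.
static int try_decode(llama_context * ctx, llama_batch batch) {
    const int ret = llama_decode(ctx, batch);
    if (ret == 1) {
        // retryable: no KV cache slot found - shrink the batch, free
        // finished sequences, or use a larger context, then decode again
        fprintf(stderr, "no KV cache slot for batch of size %d\n", batch.n_tokens);
    } else if (ret != 0) {
        // fatal: the batch cannot be processed
        fprintf(stderr, "decode failed, ret = %d\n", ret);
    }
    return ret;
}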
src/llama-context.h

Lines changed: 2 additions & 1 deletion

@@ -50,8 +50,9 @@ struct llama_context {
     llama_kv_cache * get_kv_self();
     const llama_kv_cache * get_kv_self() const;
 
+    // return true if the KV cache was updated
     // TODO: remove
-    void kv_self_update();
+    bool kv_self_update();
 
     enum llama_pooling_type pooling_type() const;

src/llama-kv-cache.cpp

Lines changed: 3 additions & 2 deletions

@@ -1809,9 +1809,10 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
     GGML_UNUSED(embd_pooled);
 
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+    // TODO: if we fail with split_simple, we should attempt different splitting strategies
+    // but to do that properly, we first have to refactor the batches to be more flexible
 
-    // TODO: if we fail with split_simple, we should attempt split_equal
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
     std::vector<llama_ubatch> ubatches;

tools/server/server.cpp

Lines changed: 1 addition & 1 deletion

@@ -3431,7 +3431,7 @@ struct server_context {
                     // retry with half the batch size to try to find a free slot in the KV cache
                     n_batch /= 2;
 
-                    SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
+                    SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
 
                     continue; // continue loop of n_batch
                 }
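
Dropping the defragmentation hint is consistent with the new behavior: decode() now schedules and applies the defrag itself, so advising the user to enable it manually would be stale. Halving n_batch remains useful because a smaller ubatch needs less free space in the KV cache at once and is therefore easier to place in a nearly full cache.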
