
Commit ca69f32

llama : auto-batch
ggml-ci
1 parent f23e4cc

File tree

3 files changed, +80 -87 lines


src/llama-context.cpp

Lines changed: 54 additions & 36 deletions
@@ -424,9 +424,9 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-void llama_context::kv_self_update() {
+bool llama_context::kv_self_update() {
     if (!memory) {
-        return;
+        return false;
     }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
@@ -445,7 +445,11 @@ void llama_context::kv_self_update() {
         if (!gf) {
             LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
         }
+
+        return true;
     }
+
+    return false;
 }
 
 enum llama_pooling_type llama_context::pooling_type() const {
@@ -933,25 +937,53 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();
 
-    auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-    if (!kv_state) {
-        return -2;
-    }
+    llama_memory_state_ptr kv_state;
 
-    switch (kv_state->get_status()) {
-        case LLAMA_MEMORY_STATUS_SUCCESS:
-            {
-            } break;
-        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
-            {
-                // not a fatal error, we can re-try with a different batch
-                return 1;
-            }
-        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
-            {
-                return -2;
-            }
-    }
+    bool did_defrag = false;
+    auto n_ubatch = cparams.n_ubatch;
+
+    do {
+        kv_state = kv_self->init_batch(batch, n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!kv_state) {
+            return -2;
+        }
+
+        switch (kv_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                } break;
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+                {
+                    if (!did_defrag) {
+                        did_defrag = true;
+
+                        kv_self->defrag_sched(-1.0f);
+                        if (kv_self_update()) {
+                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+
+                            continue;
+                        }
+                    }
+
+                    if (n_ubatch > 1) {
+                        n_ubatch /= 2;
+
+                        LLAMA_LOG_DEBUG("%s: failed to find free space in the KV cache, retrying with smaller ubatch size: n_ubatch = %d\n", __func__, n_ubatch);
+                        continue;
+                    }
+
+                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+                    return 1;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    return -2;
+                }
+        }
+
+        break;
+    } while(true);
 
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -2646,22 +2678,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
           llama_batch   batch) {
-    int ret = ctx->decode(batch);
-
-    // defrag and try again
-    // TODO: distinguish return code when we are sure that even after defrag there is no space available
-    if (ret == 1) {
-        llama_kv_self_defrag(ctx);
-        ret = ctx->decode(batch);
-
-        if (ret == 1) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
-            return ret;
-        }
-    }
-
-    if (ret != 0) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
 
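Note: after this change the llama_decode() wrapper no longer defrags and retries itself; the retries (schedule a defrag, then halve n_ubatch) happen inside llama_context::decode(), and a return value of 1 now reaches the caller, meaning that no KV cache slot was found even after those retries. A minimal caller-side sketch of the resulting contract follows — hypothetical usage, not part of this commit; the decode_tokens helper and surrounding setup are illustrative, only llama_decode() and llama_batch_get_one() are existing llama.h calls.

// sketch: hypothetical caller handling the new llama_decode() return codes
#include <cstdio>
#include <vector>

#include "llama.h"

static int decode_tokens(llama_context * ctx, std::vector<llama_token> & tokens) {
    const int ret = llama_decode(ctx, llama_batch_get_one(tokens.data(), (int32_t) tokens.size()));

    if (ret == 1) {
        // not fatal: no KV cache slot was found even after the internal
        // defrag + ubatch-halving retries - free some sequences or retry
        // with a smaller batch
        fprintf(stderr, "KV cache full for batch of %zu tokens\n", tokens.size());
    } else if (ret != 0) {
        // fatal: invalid input batch (ret == -1) or compute error (ret < -1)
        fprintf(stderr, "decode failed, ret = %d\n", ret);
    }

    return ret;
}
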
src/llama-context.h

Lines changed: 2 additions & 1 deletion
@@ -50,8 +50,9 @@ struct llama_context {
           llama_kv_cache * get_kv_self();
     const llama_kv_cache * get_kv_self() const;
 
+    // return true if the KV cache was updated
     // TODO: remove
-    void kv_self_update();
+    bool kv_self_update();
 
     enum llama_pooling_type pooling_type() const;
 
tools/server/server.cpp

Lines changed: 24 additions & 50 deletions
@@ -3385,75 +3385,49 @@ struct server_context {
         }
 
         // process the created batch of tokens
-        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
-
-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-            };
-
-            const int ret = llama_decode(ctx, batch_view);
-
-            metrics.on_decoded(slots);
+        {
+            const int ret = llama_decode(ctx, batch);
 
             if (ret != 0) {
-                {
-                    std::string err;
-
-                    if (n_batch == 1 && ret == 1) {
-                        err = "Context size has been exceeded.";
-                    }
-
-                    if (ret == -1) {
-                        err = "Invalid input batch.";
-                    }
+                std::string err;
 
-                    if (ret < -1) {
-                        err = "Compute error.";
-                    }
-
-                    if (!err.empty()) {
-                        SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
-                        for (auto & slot : slots) {
-                            slot.release();
-                            send_error(slot, err);
-                        }
-                        break;
-                    }
+                if (ret == 1) {
+                    err = "Context size has been exceeded.";
                 }
 
-                // retry with half the batch size to try to find a free slot in the KV cache
-                n_batch /= 2;
+                if (ret == -1) {
+                    err = "Invalid input batch.";
+                }
 
-                SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
+                if (ret < -1) {
+                    err = "Compute error.";
+                }
 
-                i -= n_batch;
+                if (!err.empty()) {
+                    SRV_ERR("%s, n_batch = %d, ret = %d\n", err.c_str(), n_batch, ret);
+                    for (auto & slot : slots) {
+                        slot.release();
+                        send_error(slot, err);
+                    }
 
-                continue; // continue loop of n_batch
+                    return;
+                }
             }
 
-            for (auto & slot : slots) {
-                if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
-                    continue; // continue loop of slots
-                }
+            metrics.on_decoded(slots);
 
+            for (auto & slot : slots) {
                 if (slot.state == SLOT_STATE_DONE_PROMPT) {
                     if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
                         // prompt evaluated for embedding
-                        send_embedding(slot, batch_view);
+                        send_embedding(slot, batch);
                         slot.release();
                         slot.i_batch = -1;
                         continue; // continue loop of slots
                     }
 
                     if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
-                        send_rerank(slot, batch_view);
+                        send_rerank(slot, batch);
                         slot.release();
                         slot.i_batch = -1;
                         continue; // continue loop of slots
@@ -3465,7 +3439,7 @@ struct server_context {
                     continue; // continue loop of slots
                 }
 
-                const int tok_idx = slot.i_batch - i;
+                const int tok_idx = slot.i_batch;
 
                 llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
 