Skip to content

Commit 2dba70d

Browse files
committed
context : enable reranking with encode()
ggml-ci
1 parent 5f5c3b7 commit 2dba70d

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

src/llama-context.cpp

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -790,11 +790,18 @@ int llama_context::encode(llama_batch & inp_batch) {
790790
} break;
791791
case LLAMA_POOLING_TYPE_RANK:
792792
{
793-
// TODO: this likely should be the same logic as in llama_decoder_internal, but better to
794-
// wait for an encoder model that requires this pooling type in order to test it
795-
// https://github.com/ggerganov/llama.cpp/pull/9510
796-
GGML_ABORT("RANK pooling not implemented yet");
797-
}
793+
// extract the rerank score - a single float per sequence
794+
auto & embd_seq_out = embd_seq;
795+
796+
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
797+
const llama_seq_id seq_id = ubatch.seq_id[s][0];
798+
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
799+
continue;
800+
}
801+
embd_seq_out[seq_id].resize(1);
802+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
803+
}
804+
} break;
798805
case LLAMA_POOLING_TYPE_UNSPECIFIED:
799806
{
800807
GGML_ABORT("unknown pooling type");

0 commit comments

Comments (0)