Commit 69cd94e

update embedding/retrieval/gritlm examples for pooling changes
1 parent f8c5fcb commit 69cd94e

File tree

3 files changed: +59 -21 lines

examples/embedding/embedding.cpp
examples/gritlm/gritlm.cpp
examples/retrieval/retrieval.cpp

examples/embedding/embedding.cpp

Lines changed: 27 additions & 11 deletions
@@ -17,9 +17,25 @@ static std::vector<std::string> split_lines(const std::string & s) {
     return lines;
 }
 
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
-    for (size_t i = 0; i < tokens.size(); i++) {
-        llama_batch_add(batch, tokens[i], i, { seq_id }, true);
+static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
+    switch (pooling_type) {
+        case LLAMA_POOLING_TYPE_MEAN:
+        case LLAMA_POOLING_TYPE_NONE:
+            return true;
+        case LLAMA_POOLING_TYPE_CLS:
+            return pos == 0;
+        case LLAMA_POOLING_TYPE_LAST:
+            return pos == n_tokens - 1;
+        default:
+            GGML_ASSERT(false && "unsupported pooling type");
+    }
+}
+
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id, enum llama_pooling_type pooling_type) {
+    int n_tokens = tokens.size();
+    for (size_t i = 0; i < n_tokens; i++) {
+        bool logit = needs_logit(pooling_type, i, n_tokens);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
     }
 }
 
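For reference, a minimal standalone sketch (not part of the commit) of what the new needs_logit() helper decides: which token positions must request an output for a given pooling type. MEAN and NONE need every position, CLS only position 0, LAST only the final one. The enum and names below are simplified stand-ins, not the llama.cpp API.

// needs_logit_demo.cpp -- illustrative only; simplified stand-in types.
#include <cstdio>

enum pooling_type { POOLING_NONE, POOLING_MEAN, POOLING_CLS, POOLING_LAST };

static bool needs_logit(pooling_type pt, int pos, int n_tokens) {
    switch (pt) {
        case POOLING_MEAN:
        case POOLING_NONE: return true;                 // every position contributes
        case POOLING_CLS:  return pos == 0;             // only the first (CLS) token
        case POOLING_LAST: return pos == n_tokens - 1;  // only the final token
    }
    return false;
}

int main() {
    const int n_tokens = 5;
    const pooling_type types[] = { POOLING_MEAN, POOLING_CLS, POOLING_LAST };
    const char * names[]       = { "MEAN", "CLS", "LAST" };
    for (int t = 0; t < 3; t++) {
        std::printf("%-4s:", names[t]);
        for (int pos = 0; pos < n_tokens; pos++) {
            std::printf(" %d", needs_logit(types[t], pos, n_tokens) ? 1 : 0);
        }
        std::printf("\n");  // MEAN: 1 1 1 1 1   CLS: 1 0 0 0 0   LAST: 0 0 0 0 1
    }
    return 0;
}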

@@ -40,13 +56,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
 
         // try to get sequence embeddings - supported only when pooling_type is not NONE
         const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-        if (embd == NULL) {
-            embd = llama_get_embeddings_ith(ctx, i);
-            if (embd == NULL) {
-                fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
-                continue;
-            }
-        }
+        GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
 
         float * out = output + batch.seq_id[i][0] * n_embd;
         //TODO: I would also add a parameter here to enable normalization or not.
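The TODO above concerns making normalization optional. A minimal sketch (not from the commit; the helper name and signature are assumptions) of the L2 normalization a pooled embedding typically gets before cosine comparisons:

// embd_normalize_sketch.cpp -- illustrative only; not a llama.cpp function.
#include <cmath>

static void embd_normalize(const float * src, float * dst, int n) {
    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        sum += (double) src[i] * src[i];
    }
    const float scale = sum > 0.0 ? (float) (1.0 / std::sqrt(sum)) : 0.0f;
    for (int i = 0; i < n; i++) {
        dst[i] = src[i] * scale;  // unit-length copy into the output slot
    }
}

With unit-length embeddings, the cosine similarity used downstream reduces to a plain dot product.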
@@ -99,6 +109,12 @@ int main(int argc, char ** argv) {
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
 
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+        return 1;
+    }
+
     if (n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
                 __func__, n_ctx_train, n_ctx);
@@ -178,7 +194,7 @@ int main(int argc, char ** argv) {
        }
 
        // add to batch
-        batch_add_seq(batch, inp, s);
+        batch_add_seq(batch, inp, s, pooling_type);
         s += 1;
     }
 

examples/gritlm/gritlm.cpp

Lines changed: 11 additions & 6 deletions
@@ -164,9 +164,13 @@ int main(int argc, char * argv[]) {
 
     llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
 
-    // create new context - set to embedding mode
+    // create generation context
+    llama_context * ctx_gen = llama_new_context_with_model(mdl, cparams);
+
+    // create embedding context
     cparams.embeddings = true;
-    llama_context * ctx = llama_new_context_with_model(mdl, cparams);
+    cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
+    llama_context * ctx_emb = llama_new_context_with_model(mdl, cparams);
 
     // ### Embedding/Representation ###
     // samples taken from: https://github.com/ContextualAI/gritlm#basic
@@ -184,8 +188,8 @@ int main(int argc, char * argv[]) {
     };
 
     // No need to add instruction for retrieval documents
-    const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
-    const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
+    const std::vector<std::vector<float>> d_rep = encode(ctx_emb, documents, gritlm_instruction(""));
+    const std::vector<std::vector<float>> q_rep = encode(ctx_emb, queries, gritlm_instruction(instruction));
 
     const int n_embd = llama_n_embd(mdl);
 

@@ -204,10 +208,11 @@ int main(int argc, char * argv[]) {
     // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
     {
         const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
-        std::string response = generate(ctx, prompt, true);
+        std::string response = generate(ctx_gen, prompt, true);
     }
 
-    llama_free(ctx);
+    llama_free(ctx_gen);
+    llama_free(ctx_emb);
     llama_free_model(mdl);
     llama_backend_free();
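Condensed, a sketch of the two-context pattern these hunks introduce in gritlm.cpp (the model path is a placeholder; error handling and the full example's backend init/teardown are omitted): one plain context for generation, and a second one with embeddings enabled and pooling set to NONE so the example can pool token embeddings itself.

// two_contexts_sketch.cpp -- illustrative only; "model.gguf" is a placeholder.
#include "llama.h"

int main() {
    llama_model_params   mparams = llama_model_default_params();
    llama_context_params cparams = llama_context_default_params();

    llama_model * mdl = llama_load_model_from_file("model.gguf", mparams);

    // generation context: default params, embeddings disabled
    llama_context * ctx_gen = llama_new_context_with_model(mdl, cparams);

    // embedding context: embeddings on, token-level outputs (pooling NONE)
    cparams.embeddings   = true;
    cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
    llama_context * ctx_emb = llama_new_context_with_model(mdl, cparams);

    // ... encode() with ctx_emb, generate() with ctx_gen ...

    llama_free(ctx_gen);
    llama_free(ctx_emb);
    llama_free_model(mdl);
    return 0;
}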

examples/retrieval/retrieval.cpp

Lines changed: 21 additions & 4 deletions
@@ -133,9 +133,25 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
     return chunks;
 }
 
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
+static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
+    switch (pooling_type) {
+        case LLAMA_POOLING_TYPE_MEAN:
+        case LLAMA_POOLING_TYPE_NONE:
+            return true;
+        case LLAMA_POOLING_TYPE_CLS:
+            return pos == 0;
+        case LLAMA_POOLING_TYPE_LAST:
+            return pos == n_tokens - 1;
+        default:
+            GGML_ASSERT(false && "unsupported pooling type");
+    }
+}
+
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id, enum llama_pooling_type pooling_type) {
+    int n_tokens = tokens.size();
     for (size_t i = 0; i < tokens.size(); i++) {
-        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
+        bool logit = needs_logit(pooling_type, i, n_tokens);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
     }
 }

@@ -217,6 +233,7 @@ int main(int argc, char ** argv) {
 
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
 
     if (n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
@@ -288,7 +305,7 @@ int main(int argc, char ** argv) {
        }
 
        // add to batch
-        batch_add_seq(batch, inp, s);
+        batch_add_seq(batch, inp, s, pooling_type);
         s += 1;
     }
 

@@ -311,7 +328,7 @@ int main(int argc, char ** argv) {
         std::vector<int32_t> query_tokens = llama_tokenize(ctx, query, true);
 
         struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);
-        batch_add_seq(query_batch, query_tokens, 0);
+        batch_add_seq(query_batch, query_tokens, 0, pooling_type);
 
         std::vector<float> query_emb(n_embd, 0);
         batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
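After the query batch is decoded, retrieval.cpp scores each chunk against query_emb. A standalone sketch (not from the commit, plain C++ only) of the cosine similarity such a ranking relies on when the embeddings are not already unit-normalized:

// cosine_similarity_sketch.cpp -- illustrative only.
#include <cmath>
#include <vector>

static float cosine_similarity(const std::vector<float> & a, const std::vector<float> & b) {
    double dot = 0.0, na = 0.0, nb = 0.0;
    const size_t n = a.size() < b.size() ? a.size() : b.size();
    for (size_t i = 0; i < n; i++) {
        dot += (double) a[i] * b[i];
        na  += (double) a[i] * a[i];
        nb  += (double) b[i] * b[i];
    }
    if (na == 0.0 || nb == 0.0) {
        return 0.0f;  // guard against all-zero vectors
    }
    return (float) (dot / (std::sqrt(na) * std::sqrt(nb)));
}

Ranking then amounts to computing this score for the query against every chunk embedding and sorting in descending order.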
