
Commit 655e9c9

ngxson authored and jwcolin committed
mtmd : add methods to access mtmd_image_tokens (ggml-org#12906)
* mtmd : add more api around mtmd_image_tokens
* mtmd : ability to calc image hash
* shared_ptr for mtmd_image_tokens
* move hash to user-define ID (fixed)
* fix prompt_modified
* rm redundant data member
1 parent c2e0025 commit 655e9c9
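
In practice, the change inverts ownership: mtmd_tokenize no longer returns a heap-allocated chunk list that the caller must later free; the caller passes in an mtmd_input_chunks by reference, receives an int32_t status, and the image data inside each chunk is released automatically through mtmd_image_tokens_ptr. A minimal caller-side sketch of the new convention (the context handles, prompt, and bitmap setup are placeholders, not part of this commit; the pattern mirrors the gemma3-cli change below):

    // Assumes an initialized mtmd_context * ctx_vision, a llama_context * lctx,
    // and n_past / n_batch counters, as in the gemma3-cli example below.
    mtmd_input_text text;
    text.text          = formatted_prompt;  // hypothetical prompt containing image markers
    text.add_special   = true;
    text.parse_special = true;

    std::vector<mtmd_bitmap> bitmaps;        // one bitmap per image marker in the prompt
    mtmd_input_chunks chunks;                // caller-owned; cleans itself up on scope exit

    int32_t res = mtmd_tokenize(ctx_vision, chunks, text, bitmaps);
    if (res != 0) {
        LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
        return 1;
    }
    if (mtmd_helper_eval(ctx_vision, lctx, chunks, n_past, 0, n_batch)) {
        LOG_ERR("Unable to eval prompt\n");
        return 1;
    }
    n_past += mtmd_helper_get_n_tokens(chunks); // counts text and image tokens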

File tree

3 files changed: +92 -44 lines changed


examples/llava/gemma3-cli.cpp

Lines changed: 6 additions & 5 deletions
@@ -184,18 +184,19 @@ static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector
     text.text = formatted_chat.prompt;
     text.add_special = add_bos;
     text.parse_special = true;
-    mtmd_input_chunks_ptr chunks(mtmd_tokenize(ctx.ctx_vision.get(), text, bitmaps));
-    if (chunks == nullptr) {
-        LOG_ERR("Unable to tokenize prompt\n");
+    mtmd_input_chunks chunks;
+    int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
+    if (res != 0) {
+        LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
         return 1;
     }
 
-    if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks.get(), ctx.n_past, 0, ctx.n_batch)) {
+    if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
         LOG_ERR("Unable to eval prompt\n");
         return 1;
     }
 
-    ctx.n_past += mtmd_helper_get_n_tokens(chunks.get());
+    ctx.n_past += mtmd_helper_get_n_tokens(chunks);
 
     return 0;
 }
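
The CLI above treats any non-zero result as fatal; a caller that wants friendlier diagnostics could map the return values documented in mtmd.h (0 success, 1 image/marker count mismatch, 2 preprocessing failure). A small sketch, not part of this commit:

    switch (res) {
        case 0:
            break; // tokenization succeeded
        case 1:
            LOG_ERR("number of images does not match the number of image markers\n");
            return 1;
        case 2:
            LOG_ERR("image preprocessing failed\n");
            return 1;
        default:
            LOG_ERR("unexpected mtmd_tokenize result %d\n", res);
            return 1;
    }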

examples/llava/mtmd.cpp

Lines changed: 60 additions & 28 deletions
@@ -16,6 +16,7 @@ struct mtmd_context {
     struct clip_ctx * ctx_clip;
     const struct llama_model * text_model;
     std::vector<float> image_embd_v; // image embedding vector
+
     bool print_timings;
     int n_threads;
     std::string image_marker;
@@ -24,7 +25,11 @@
 
     mtmd_context(const char * mmproj_fname,
             const llama_model * text_model,
-            const mtmd_context_params & ctx_params) : print_timings(ctx_params.print_timings), n_threads(ctx_params.n_threads), image_marker(ctx_params.image_marker) {
+            const mtmd_context_params & ctx_params) :
+        print_timings(ctx_params.print_timings),
+        n_threads    (ctx_params.n_threads),
+        image_marker (ctx_params.image_marker)
+    {
         clip_context_params ctx_clip_params;
         ctx_clip_params.use_gpu = ctx_params.use_gpu;
         ctx_clip_params.verbosity = ctx_params.verbosity;
@@ -49,6 +54,7 @@ struct mtmd_image_tokens {
     uint32_t ny; // number of tokens in y direction
     uint32_t n_tokens() const { return nx * ny; }
     clip_image_f32_batch batch_f32; // preprocessed image patches
+    std::string id; // optional user-defined ID, useful for KV cache tracking
 };
 
 mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
@@ -88,10 +94,10 @@ static std::vector<llama_token> mtmd_tokenize_text_internal(
     return result;
 }
 
-mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
-                                  const mtmd_input_text & text,
-                                  const std::vector<mtmd_bitmap> & bitmaps) {
-    mtmd_input_chunks * output = new mtmd_input_chunks;
+int32_t mtmd_tokenize(mtmd_context * ctx,
+                      std::vector<mtmd_input_chunk> & output,
+                      const mtmd_input_text & text,
+                      const std::vector<mtmd_bitmap> & bitmaps) {
     auto vocab = llama_model_get_vocab(ctx->text_model);
 
     std::string prompt_modified(text.text);
@@ -105,9 +111,9 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
         string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
     }
 
-    std::vector<std::string> parts = string_split_str(text.text, ctx->image_marker);
-    output->clear();
-    output->reserve(parts.size());
+    std::vector<std::string> parts = string_split_str(prompt_modified, ctx->image_marker);
+    output.clear();
+    output.reserve(parts.size());
 
     size_t i_img = 0;
 
@@ -123,14 +129,14 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
             std::move(tokens),
             {},
         };
-        output->emplace_back(std::move(chunk));
+        output.emplace_back(std::move(chunk));
 
         if (&parts.back() != &part) {
             // add image token to middle of 2 parts
 
             if (i_img >= bitmaps.size()) {
                 LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
-                return nullptr;
+                return 1;
             }
 
             // shim layer
@@ -145,34 +151,48 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
             bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), &batch_f32);
             if (!ok) {
                 LOG_ERR("Unable to preprocess image\n");
-                return nullptr;
+                return 2;
             }
 
-            mtmd_image_tokens * image_tokens = new mtmd_image_tokens;
+            mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
             image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image
             image_tokens->ny = 1; // TODO
             image_tokens->batch_f32 = std::move(batch_f32);
+            image_tokens->id = bitmaps[i_img].id; // optional
 
             mtmd_input_chunk chunk{
                 MTMD_INPUT_CHUNK_TYPE_IMAGE,
                 {},
-                image_tokens,
+                std::move(image_tokens),
             };
-            output->emplace_back(std::move(chunk));
+            output.emplace_back(std::move(chunk));
             i_img++;
         }
     }
 
-    return output;
+    return 0;
 }
 
-void mtmd_input_chunks_free(mtmd_input_chunks * chunks) {
-    for (auto & chunk : *chunks) {
-        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) {
-            delete chunk.tokens_image;
-        }
+void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
+    if (image_tokens) {
+        delete image_tokens;
     }
-    delete chunks;
+}
+
+size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) {
+    return image_tokens->n_tokens();
+}
+
+size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) {
+    return image_tokens->nx;
+}
+
+size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {
+    return image_tokens->ny;
+}
+
+std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
+    return image_tokens->id;
 }
 
 int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
@@ -190,9 +210,9 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
     return ctx->image_embd_v.data();
 }
 
-size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks) {
+size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) {
     size_t n_tokens = 0;
-    for (auto & chunk : *chunks) {
+    for (auto & chunk : chunks) {
         if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             n_tokens += chunk.tokens_text.size();
         } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
@@ -241,16 +261,16 @@ struct decode_embd_batch {
 
 int32_t mtmd_helper_eval(mtmd_context * ctx,
         llama_context * lctx,
-        mtmd_input_chunks * chunks,
+        mtmd_input_chunks & chunks,
         llama_pos pos0,
         llama_seq_id seq_id,
         int32_t n_batch) {
     int32_t ret;
     llama_pos n_past = pos0;
     llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
 
-    for (auto & chunk : *chunks) {
-        bool is_last = &chunk == &chunks->back();
+    for (auto & chunk : chunks) {
+        bool is_last = &chunk == &chunks.back();
         if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             // TODO @ngxson : may need to split into smaller batches
             text_batch.n_tokens = chunk.tokens_text.size();
@@ -279,7 +299,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
             if (ctx->print_timings) {
                 LOG_INF("encoding image...\n");
             }
-            ret = mtmd_encode(ctx, chunk.tokens_image);
+            ret = mtmd_encode(ctx, chunk.tokens_image.get());
             if (ret != 0) {
                 LOG_ERR("failed to encode image\n");
                 llama_batch_free(text_batch);
@@ -289,7 +309,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
                 LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
             }
 
-            int32_t n_tokens = chunk.tokens_image->n_tokens();
+            int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get());
             float * embd = mtmd_get_output_embd(ctx);
             decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
             int64_t t1 = ggml_time_ms();
@@ -339,3 +359,15 @@ int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & outp
     std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
     return 0;
 }
+
+bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
+    projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
+    if (proj_type == PROJECTOR_TYPE_GEMMA3) {
+        return true;
+    }
+    return false;
+}
+
+void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
+    mtmd_image_tokens_free(val);
+}
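
Since mtmd_image_tokens is now reachable only through these accessor functions, downstream code (for example a server keeping a KV cache keyed by image ID) can inspect chunks without knowing the struct layout. A hedged sketch, assuming chunks came from a successful mtmd_tokenize call:

    for (auto & chunk : chunks) {
        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
            const mtmd_image_tokens * img = chunk.tokens_image.get();
            LOG_INF("image chunk: %zu tokens (%zu x %zu), id = '%s'\n",
                    mtmd_image_tokens_get_n_tokens(img),
                    mtmd_image_tokens_get_nx(img),
                    mtmd_image_tokens_get_ny(img),
                    mtmd_image_tokens_get_id(img).c_str());
        }
    }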

examples/llava/mtmd.h

Lines changed: 26 additions & 11 deletions
@@ -39,12 +39,18 @@ struct mtmd_bitmap {
     uint32_t nx;
     uint32_t ny;
     std::vector<unsigned char> data;
+    std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking
 };
 
+struct mtmd_image_tokens_deleter {
+    void operator()(mtmd_image_tokens * val); // forward declaration
+};
+using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;
+
 struct mtmd_input_chunk {
     mtmd_input_chunk_type type;
     std::vector<llama_token> tokens_text;
-    mtmd_image_tokens * tokens_image = nullptr;
+    mtmd_image_tokens_ptr tokens_image;
 };
 
 using mtmd_input_chunks = std::vector<mtmd_input_chunk>;
@@ -82,12 +88,21 @@ MTMD_API void mtmd_free(mtmd_context * ctx);
 // 3. "<end_of_image>\ndescribe it in detail."
 // number of bitmaps must be equal to the number of image markers in the prompt
 // this function is thread-safe (shared ctx)
-MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
+// return values:
+//   0 on success
+//   1 on number of images not matching the number of markers
+//   2 on image preprocessing error
+MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
+                               std::vector<mtmd_input_chunk> & output,
                                const mtmd_input_text & text,
                                const std::vector<mtmd_bitmap> & bitmaps);
 
-// free image chunk data
-MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks);
+// access mtmd_image_tokens
+MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens);
+MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens);
+MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens);
+MTMD_API std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens);
+MTMD_API void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens);
 
 // returns 0 on success
 MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
@@ -96,12 +111,17 @@ MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
 // get output embeddings from the last encode pass
 MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
 
+// whether we need to set non-causal mask before llama_decode
+MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
+
+
+
 //
 // helper functions (can be implemented based on other functions)
 //
 
 // helper to count the total number of tokens from a list of chunks, useful to keep track of n_past
-MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks);
+MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks);
 
 // helper function that automatically:
 // 1. run llama_decode() on text chunks
@@ -110,7 +130,7 @@ MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks);
 // otherwise, returns 0 on success
 MTMD_API int32_t mtmd_helper_eval(mtmd_context * ctx,
                                   llama_context * lctx,
-                                  mtmd_input_chunks * chunks,
+                                  mtmd_input_chunks & chunks,
                                   llama_pos pos0,
                                   llama_seq_id seq_id,
                                   int32_t n_batch);
@@ -132,11 +152,6 @@ struct mtmd_context_deleter {
 };
 using mtmd_context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>;
 
-struct mtmd_input_chunks_deleter {
-    void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); }
-};
-using mtmd_input_chunks_ptr = std::unique_ptr<mtmd_input_chunks, mtmd_input_chunks_deleter>;
-
 #else
 
 static_assert(false && "C header is not yet supported by this library");
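
Two notes on the header changes. First, the out-of-line mtmd_image_tokens_deleter lets mtmd_image_tokens_ptr be declared here while mtmd_image_tokens stays an incomplete type; a plain std::unique_ptr<mtmd_image_tokens> would need the full definition to call delete. Second, mtmd_decode_use_non_causal is a hint for decode loops; a sketch of how a caller might honor it (llama_set_causal_attn is the existing llama.cpp call for toggling the mask; this usage is an assumption, not code from this commit):

    if (mtmd_decode_use_non_causal(ctx)) {
        llama_set_causal_attn(lctx, false); // e.g. Gemma 3 attends bidirectionally over image tokens
    }
    // ... llama_decode() on the image embedding batch ...
    llama_set_causal_attn(lctx, true);      // restore causal masking for text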

0 commit comments