Skip to content

mtmd : add methods to access mtmd_image_tokens #12906

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 18, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/llava/gemma3-cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ struct gemma3_context {
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mtmd_context_params{
/* use_gpu */ true,
/* timings */ true,
/* hash */ false,
/* n_threads */ params.cpuparams.n_threads,
/* verbosity */ GGML_LOG_LEVEL_INFO,
}));
Expand Down
68 changes: 62 additions & 6 deletions examples/llava/mtmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,22 @@ struct mtmd_context {
struct clip_ctx * ctx_clip;
const struct llama_model * text_model;
std::vector<float> image_embd_v; // image embedding vector

bool print_timings;
int n_threads;
std::string image_marker;
bool calc_image_hash;

// TODO @ngxson : add timings

mtmd_context(const char * mmproj_fname,
const llama_model * text_model,
const mtmd_context_params & ctx_params) : print_timings(ctx_params.print_timings), n_threads(ctx_params.n_threads), image_marker(ctx_params.image_marker) {
const mtmd_context_params & ctx_params) :
print_timings (ctx_params.print_timings),
n_threads (ctx_params.n_threads),
image_marker (ctx_params.image_marker),
calc_image_hash(ctx_params.calc_image_hash)
{
clip_context_params ctx_clip_params;
ctx_clip_params.use_gpu = ctx_params.use_gpu;
ctx_clip_params.verbosity = ctx_params.verbosity;
Expand All @@ -49,6 +56,7 @@ struct mtmd_image_tokens {
uint32_t ny; // number of tokens in y direction
uint32_t n_tokens() const { return nx * ny; }
clip_image_f32_batch batch_f32; // preprocessed image patches
size_t image_hash = 0; // hash of the image, useful for KV cache tracking
};

mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
Expand Down Expand Up @@ -88,6 +96,16 @@ static std::vector<llama_token> mtmd_tokenize_text_internal(
return result;
}

static uint64_t hash_vector_float(const std::vector<float> & vec) {
uint64_t seed = vec.size();
std::hash<float> hasher;
for (float val : vec) {
// inspired by boost::hash_combine
seed ^= hasher(val) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}

mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
const mtmd_input_text & text,
const std::vector<mtmd_bitmap> & bitmaps) {
Expand Down Expand Up @@ -153,6 +171,11 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
image_tokens->ny = 1; // TODO
image_tokens->batch_f32 = std::move(batch_f32);

// optionally calculate the hash
if (ctx->calc_image_hash) {
image_tokens->image_hash = hash_vector_float(image_tokens->batch_f32.entries[0]->buf);
}

mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_IMAGE,
{},
Expand All @@ -166,15 +189,40 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
return output;
}

void mtmd_input_chunks_free(mtmd_input_chunks * chunks) {
for (auto & chunk : *chunks) {
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) {
delete chunk.tokens_image;
void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
if (image_tokens) {
delete image_tokens;
}
}

void mtmd_input_chunks_free(mtmd_input_chunks * chunks, bool free_images) {
if (free_images) {
for (auto & chunk : *chunks) {
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) {
mtmd_image_tokens_free(chunk.tokens_image);
chunk.tokens_image = nullptr;
}
}
}
delete chunks;
}

size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) {
return image_tokens->n_tokens();
}

size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) {
return image_tokens->nx;
}

size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {
return image_tokens->ny;
}

uint64_t mtmd_image_tokens_get_hash(const mtmd_image_tokens * image_tokens) {
return image_tokens->image_hash;
}

int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
Expand Down Expand Up @@ -289,7 +337,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
}

int32_t n_tokens = chunk.tokens_image->n_tokens();
int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image);
float * embd = mtmd_get_output_embd(ctx);
decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
int64_t t1 = ggml_time_ms();
Expand Down Expand Up @@ -339,3 +387,11 @@ int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & outp
std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
return 0;
}

bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
if (proj_type == PROJECTOR_TYPE_GEMMA3) {
return true;
}
return false;
}
27 changes: 24 additions & 3 deletions examples/llava/mtmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ using mtmd_input_chunks = std::vector<mtmd_input_chunk>;
struct mtmd_context_params {
bool use_gpu = true;
bool print_timings = true;
// calc_image_hash is useful for tracking KV cache
// if not set, mtmd_image_tokens_get_hash will return 0
bool calc_image_hash = false;
int n_threads = 4;
enum ggml_log_level verbosity = GGML_LOG_LEVEL_INFO;
const char * image_marker = "<__image__>";
Expand Down Expand Up @@ -81,13 +84,21 @@ MTMD_API void mtmd_free(mtmd_context * ctx);
// 2. (image tokens)
// 3. "<end_of_image>\ndescribe it in detail."
// number of bitmaps must be equal to the number of image markers in the prompt
// the returned value must be freed using mtmd_input_chunks_free()
// this function is thread-safe (shared ctx)
MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
const mtmd_input_text & text,
const std::vector<mtmd_bitmap> & bitmaps);

// free image chunk data
MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks);
// if free_images = true, free the image tokens ; otherwise, you must free them using mtmd_image_free()
MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks, bool free_images);

// access mtmd_image_tokens
MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens);
MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens);
MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens);
MTMD_API uint64_t mtmd_image_tokens_get_hash(const mtmd_image_tokens * image_tokens);
MTMD_API void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens);

// returns 0 on success
MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
Expand All @@ -96,6 +107,11 @@ MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
// get output embeddings from the last encode pass
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);

// whether we need to set non-causal mask before llama_decode
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);



//
// helper functions (can be implemented based on other functions)
//
Expand Down Expand Up @@ -133,10 +149,15 @@ struct mtmd_context_deleter {
using mtmd_context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>;

struct mtmd_input_chunks_deleter {
void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); }
void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val, true); }
};
using mtmd_input_chunks_ptr = std::unique_ptr<mtmd_input_chunks, mtmd_input_chunks_deleter>;

struct mtmd_image_tokens_deleter {
void operator()(mtmd_image_tokens * val) { mtmd_image_tokens_free(val); }
};
using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;

#else

static_assert(false && "C header is not yet supported by this library");
Expand Down
Loading