Commit fc1c860

Merge branch 'prepare-PR-of-minicpm-v2.6' into master
2 parents 911b437 + ea0c828

8 files changed (+1967, -23 lines)

examples/llava/clip.cpp

Lines changed: 135 additions & 17 deletions
@@ -3,6 +3,7 @@
 // I'll gradually clean and extend it
 // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
 #include "clip.h"
+#include "common.h"
 #include "log.h"
 #include "ggml.h"
 #include "ggml-alloc.h"
@@ -81,6 +82,7 @@ static std::string format(const char * fmt, ...) {
 #define KEY_HAS_VIS_ENC "clip.has_vision_encoder"
 #define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
 #define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector"
+#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
 #define KEY_USE_GELU "clip.use_gelu"
 #define KEY_N_EMBD "clip.%s.embedding_length"
 #define KEY_N_FF "clip.%s.feed_forward_length"
@@ -526,6 +528,7 @@ struct clip_ctx {
     bool has_vision_encoder = false;
     bool has_llava_projector = false;
     bool has_minicpmv_projector = false;
+    int minicpmv_version = 2;

     struct clip_vision_model vision_model;
     projector_type proj_type = PROJECTOR_TYPE_MLP;
@@ -641,7 +644,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     if (ctx->has_minicpmv_projector) {
         int pos_w = image_size_width/patch_size;
         int pos_h = image_size_height/patch_size;
-        pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1);
+        if (ctx->minicpmv_version == 2) {
+            pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1);
+        }
+        else if (ctx->minicpmv_version == 3) {
+            pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
+        }
         ggml_set_name(pos_embed, "pos_embed");
         ggml_set_input(pos_embed);
     }
@@ -769,7 +777,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
         embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);

-    } else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
+    }
+    else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
         embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
         embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
         // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
@@ -987,6 +996,72 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             GGML_ASSERT(false);
         }
     }
+    // minicpmv projector
+    else if (ctx->has_minicpmv_projector) {
+        if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
+            struct ggml_tensor * q = model.mm_model_query;
+            { // layernorm
+                q = ggml_norm(ctx0, q, eps);
+                q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
+            }
+            struct ggml_tensor *k, *v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
+            { // layernorm
+                v = ggml_norm(ctx0, v, eps);
+                v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b);
+            }
+            { // position
+                // q = ggml_add(ctx0, q, model.mm_model_pos_embed);
+                k = ggml_add(ctx0, v, pos_embed);
+            }
+
+            { // attention
+                int hidden_size = 4096;
+                const int d_head = 128;
+                int n_head = hidden_size/d_head;
+                int num_query = 96;
+                if (ctx->minicpmv_version == 2) {
+                    hidden_size = 4096;
+                    n_head = hidden_size/d_head;
+                    num_query = 96;
+                }
+                else if (ctx->minicpmv_version == 3) {
+                    hidden_size = 3584;
+                    n_head = hidden_size/d_head;
+                    num_query = 64;
+                }
+                struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
+                Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
+                struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
+                struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
+                // permute
+                Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
+                Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+                Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
+                K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+                K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+                K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+                V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+                V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+                V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+                KQ = ggml_soft_max_inplace(ctx0, KQ);
+                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+                KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
+                KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+                KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
+
+                embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
+            }
+            { // layernorm
+                embeddings = ggml_norm(ctx0, embeddings, eps);
+                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b);
+            }
+            embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
+        }
+        else {
+            GGML_ASSERT(false);
+        }
+    }

     // build the graph
     ggml_build_forward_expand(gf, embeddings);
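The resampler added here is a single cross-attention block: a fixed set of learned queries (mm_model_query) attends over the patch embeddings, so every image slice is compressed to num_query vectors of the language model's width. The following standalone sketch traces just the dimension logic from the hunk above; num_positions is a made-up example value, not taken from the diff.

#include <cstdio>
#include <initializer_list>

// Standalone sketch (not ggml code): the dimension logic of the resampler
// cross-attention above. num_positions is a hypothetical example value.
int main() {
    for (int version : {2, 3}) {
        const int d_head        = 128;
        const int hidden_size   = (version == 2) ? 4096 : 3584; // matches the LM width
        const int n_head        = hidden_size / d_head;         // 32 or 28 heads
        const int num_query     = (version == 2) ? 96 : 64;     // learned query vectors
        const int num_positions = 1024;                         // e.g. (448/14)^2 patches

        printf("v%d: %d heads of dim %d, KQ per head = [%d x %d], out = [%d x %d]\n",
               version, n_head, d_head, num_positions, num_query, hidden_size, num_query);
    }
    return 0;
}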
@@ -1149,6 +1224,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         new_clip->has_minicpmv_projector = gguf_get_val_bool(ctx, idx);
     }

+    idx = gguf_find_key(ctx, KEY_MINICPMV_VERSION);
+    if (idx != -1) {
+        new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx);
+    }
+
     // GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search

     GGML_ASSERT(new_clip->has_vision_encoder);
@@ -1910,10 +1990,18 @@ int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
 // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
 // res_imgs memory is being allocated here, previous allocations will be freed if found
 bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch * res_imgs) {
-    if (clip_is_minicpmv(ctx)) {
-        std::vector<std::vector<clip_image_u8 *>> imgs = uhd_slice_image(img);
+
+    if(clip_is_minicpmv(ctx)){
+        int max_slice_nums = 9;
+        if (ctx->minicpmv_version == 2) {
+            max_slice_nums = 9;
+        }
+        else if (ctx->minicpmv_version == 3) {
+            max_slice_nums = 9;
+        }
+        std::vector<std::vector<clip_image_u8 *>> imgs = uhd_slice_image(img, max_slice_nums);
         res_imgs->size = 0;
-        for (size_t i = 0; i < imgs.size(); ++i) {
+        for (size_t i = 0; i < imgs.size(); ++i){
             res_imgs->size += imgs[i].size();
         }
         res_imgs->data = new clip_image_f32[res_imgs->size];
@@ -2146,7 +2234,12 @@ int clip_n_patches(const struct clip_ctx * ctx) {
     if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) {
         n_patches /= 4;
     } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
-        n_patches = 96;
+        if (ctx->minicpmv_version == 2) {
+            n_patches = 96;
+        }
+        else if (ctx->minicpmv_version == 3) {
+            n_patches = 64;
+        }
     }

     return n_patches;
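clip_n_patches() also drives buffer sizing (clip_embd_nbytes() multiplies it by clip_n_mmproj_embd()), so the version-3 branch shrinks the per-slice embedding payload as well. A rough check, assuming the usual n_patches * n_embd * sizeof(float) formula:

#include <cstdio>

// Back-of-envelope check of the per-slice embedding buffer, assuming
// clip_embd_nbytes(ctx) == clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float).
int main() {
    const size_t v2 = 96 * 4096 * sizeof(float); // version 2: 1536 KiB per slice
    const size_t v3 = 64 * 3584 * sizeof(float); // version 3:  896 KiB per slice
    printf("v2: %zu bytes, v3: %zu bytes\n", v2, v3);
    return 0;
}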
@@ -2282,6 +2375,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     const int patch_size = hparams.patch_size;
     const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
     const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
+    if(ctx->load_image_size==nullptr){
+        ctx->load_image_size= clip_image_size_init();
+    }
+    const int pos_w = ctx->load_image_size->width/patch_size;
+    const int pos_h = ctx->load_image_size->height/patch_size;

     {
         struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
@@ -2316,8 +2414,18 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
 // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
 struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
 int* positions_data = (int*)malloc(ggml_nbytes(positions));
-for (int i = 0; i < num_positions; i++) {
-    positions_data[i] = std::floor(70.0*i/num_positions);
+int bucket_coords_h[70];
+int bucket_coords_w[70];
+for (int i = 0; i < pos_h; i++){
+    bucket_coords_h[i] = std::floor(70.0*i/pos_h);
+}
+for (int i = 0; i < pos_w; i++){
+    bucket_coords_w[i] = std::floor(70.0*i/pos_w);
+}
+for (int i = 0, id = 0; i < pos_h; i++){
+    for (int j = 0; j < pos_w; j++){
+        positions_data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
+    }
 }
 ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
 free(positions_data);
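The replaced loop flattened all positions onto a single 70-bucket axis; the new code buckets rows and columns separately and combines them as h*70 + w, so each patch indexes a fixed 70x70 position table regardless of the slice's actual grid shape. A standalone sketch of the same computation, with example grid sizes:

#include <cmath>
#include <cstdio>
#include <vector>

// Standalone version of the bucketing loop above: map a pos_h x pos_w patch
// grid onto a fixed 70x70 position-id table. Grid sizes here are examples.
int main() {
    const int pos_h = 24, pos_w = 32; // e.g. a 336x448 slice with 14px patches
    std::vector<int> ids;
    ids.reserve(pos_h * pos_w);
    for (int i = 0; i < pos_h; i++) {
        for (int j = 0; j < pos_w; j++) {
            const int bh = (int)std::floor(70.0 * i / pos_h);
            const int bw = (int)std::floor(70.0 * j / pos_w);
            ids.push_back(bh * 70 + bw); // row-major id into the 70x70 table
        }
    }
    printf("first id = %d, last id = %d\n", ids.front(), ids.back());
    return 0;
}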
@@ -2328,12 +2436,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
 // -> https://huggingface.co/Qwen/Qwen-VL/tree/main
 // -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23
 struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed");
-if(ctx->load_image_size==nullptr){
-    ctx->load_image_size= clip_image_size_init();
-}
-int pos_w = ctx->load_image_size->width/patch_size;
-int pos_h = ctx->load_image_size->height/patch_size;
 int embed_dim = 4096;
+if (ctx->minicpmv_version == 2) {
+    embed_dim = 4096;
+}
+else if (ctx->minicpmv_version == 3) {
+    embed_dim = 3584;
+}
 auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));

 float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
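get_2d_sincos_pos_embed() itself is not shown in this diff; the sketch below is the standard MAE/Qwen-VL-style construction it presumably follows, where half of embed_dim encodes the row coordinate and half the column, each as concatenated sin/cos frequency banks. For version 3 the call above would then yield a (pos_w*pos_h) x 3584 table. Details of the real helper may differ:

#include <cmath>
#include <vector>

// Sketch of a standard 2D sin-cos positional embedding (an assumption about
// get_2d_sincos_pos_embed, not copied from it). Output: [pos_h*pos_w][embed_dim].
static std::vector<std::vector<float>> sincos_2d(int embed_dim, int pos_w, int pos_h) {
    const int half    = embed_dim / 2; // channels per axis
    const int quarter = half / 2;      // sin/cos pairs per axis
    std::vector<std::vector<float>> out(pos_h * pos_w, std::vector<float>(embed_dim));
    for (int y = 0; y < pos_h; y++) {
        for (int x = 0; x < pos_w; x++) {
            std::vector<float> & row = out[y * pos_w + x];
            for (int i = 0; i < quarter; i++) {
                const double omega = 1.0 / std::pow(10000.0, (double)i / quarter);
                row[i]                  = std::sin(y * omega); // row coordinate, sin bank
                row[i + quarter]        = std::cos(y * omega); // row coordinate, cos bank
                row[half + i]           = std::sin(x * omega); // column coordinate
                row[half + i + quarter] = std::cos(x * omega);
            }
        }
    }
    return out;
}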
@@ -2346,7 +2455,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         ggml_backend_tensor_set(pos_embed, pos_embed_data, 0, ggml_nbytes(pos_embed));
         free(pos_embed_data);
     }
-} else {
+}
+else{
     {
         if (ctx->has_class_embedding) {
             struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
@@ -2548,13 +2658,21 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         return ctx->vision_model.mm_3_b->ne[0];
     }
     if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
-        return 4096;
+        if (ctx->minicpmv_version == 2) {
+            return 4096;
+        }
+        else if (ctx->minicpmv_version == 3) {
+            return 3584;
+        }
     }

     std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
     throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
 }

-bool clip_is_minicpmv(const struct clip_ctx * ctx) {
-    return ctx->has_minicpmv_projector;
+int clip_is_minicpmv(const struct clip_ctx * ctx) {
+    if (ctx->has_minicpmv_projector) {
+        return ctx->minicpmv_version;
+    }
+    return 0;
 }

examples/llava/clip.h

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons

 CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);

-CLIP_API bool clip_is_minicpmv(const struct clip_ctx * ctx);
+CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);

 #ifdef __cplusplus
 }
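The widened return type stays backward-compatible for boolean callers, since 0 ("not MiniCPM-V") is falsy, while version-aware code can branch on the value. A hypothetical caller sketch (dispatch_minicpmv is illustrative, not from the diff):

#include "clip.h"

// Hypothetical caller: the int return value encodes both presence and version.
// Existing checks like `if (clip_is_minicpmv(ctx))` keep working unchanged.
static void dispatch_minicpmv(struct clip_ctx * ctx_clip) {
    switch (clip_is_minicpmv(ctx_clip)) {
        case 0:  /* not a MiniCPM-V model */          break;
        case 2:  /* MiniCPM-Llama3-V 2.5 code path */ break;
        case 3:  /* MiniCPM-V 2.6 code path */        break;
        default: /* future versions */                break;
    }
}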

examples/llava/llava.cpp

Lines changed: 9 additions & 2 deletions
@@ -254,9 +254,16 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
             image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip));
             int patch_size=14;
             load_image_size->width = img_res_v.data[i].nx;
-            load_image_size->height = img_res_v.data[i].ny;
+            load_image_size->height = img_res_v.data[i].ny;
             clip_add_load_image_size(ctx_clip, load_image_size);
-            const bool encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
+            bool encoded = false;
+            int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
+            if (has_minicpmv_projector == 2) {
+                encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
+            }
+            else if (has_minicpmv_projector == 3) {
+                encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
+            }
             if (!encoded) {
                 LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size);
                 return false;

examples/llava/minicpmv-cli.cpp

Lines changed: 23 additions & 3 deletions
@@ -134,7 +134,13 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
     std::string system_prompt;
     int idx = 0;
     int num_image_embeds = embeds->n_image_pos / clip_n_patches(ctx_llava->ctx_clip);
-    system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n";
+    int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
+    if (has_minicpmv_projector == 2) {
+        system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n";
+    }
+    else if (has_minicpmv_projector == 3) {
+        system_prompt = "<|im_start|>user\n";
+    }
     LOG_TEE("%s: image token past: %d\n", __func__, n_past);
     eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, false);
     process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
@@ -210,10 +216,24 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri

 static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
     std::string user_prompt = prompt;
-    if (!is_first) user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt;
+    int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
+    if (!is_first) {
+        if (has_minicpmv_projector == 2) {
+            user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt;
+        }
+        else if (has_minicpmv_projector == 3) {
+            user_prompt = "<|im_start|>user\n" + prompt;
+        }
+    }

     eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
-    eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false);
+    if (has_minicpmv_projector == 2) {
+        eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false);
+    }
+    else if (has_minicpmv_projector == 3) {
+        eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
+    }
+
     // generate the response

     LOG_TEE("\n");
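The two prompt formats reflect the different language models behind each version: the version-2 strings are Llama-3 chat markup (MiniCPM-Llama3-V 2.5), while the version-3 strings are ChatML (MiniCPM-V 2.6, consistent with the 3584 embedding width above). A small sketch assembling one full user turn the way the two branches do; the model naming is a reading of the version numbers, not stated in the diff:

#include <string>

// Sketch: build one user turn as the branches above do.
// version 2 -> Llama-3 chat markup, version 3 -> ChatML markup.
static std::string make_turn(int version, const std::string & user_text) {
    if (version == 2) {
        return "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
             + user_text
             + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n";
    }
    if (version == 3) {
        return "<|im_start|>user\n" + user_text + "<|im_end|><|im_start|>assistant\n";
    }
    return user_text; // not MiniCPM-V: leave the prompt untouched
}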