
Commit 5f1d53f

pacman100 authored and ggerganov committed
llama : add StarCoder2 support (ggml-org#5795)
* Add support for starcoder2
* handle rope type
* skip rope freq and rotary embeddings from being serialized
* resolve comments
* Update llama.cpp
* remove redundant changes
* handle `rope-theta`
* llama : change starcoder2 rope type
* address comment

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 9a7a5a2 commit 5f1d53f

File tree

convert-hf-to-gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
llama.cpp

4 files changed, +229 -1 lines changed

convert-hf-to-gguf.py

Lines changed: 7 additions & 1 deletion
@@ -96,9 +96,11 @@ def set_gguf_parameters(self):
         if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
             self.gguf_writer.add_head_count_kv(n_head_kv)
 
+        if (rope_theta := self.hparams.get("rope_theta")) is not None:
+            self.gguf_writer.add_rope_freq_base(rope_theta)
         if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
             self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
-        if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon"], optional=True)) is not None:
+        if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
             self.gguf_writer.add_layer_norm_eps(f_norm_eps)
         if (n_experts := self.hparams.get("num_local_experts")) is not None:
             self.gguf_writer.add_expert_count(n_experts)
@@ -220,6 +222,8 @@ def from_model_architecture(model_architecture):
             return NomicBertModel
         if model_architecture == "GemmaForCausalLM":
             return GemmaModel
+        if model_architecture == "Starcoder2ForCausalLM":
+            return Model
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -281,6 +285,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.NOMIC_BERT
         if arch == "GemmaForCausalLM":
             return gguf.MODEL_ARCH.GEMMA
+        if arch == "Starcoder2ForCausalLM":
+            return gguf.MODEL_ARCH.STARCODER2
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
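
The new `rope_theta` handling above simply forwards the Hugging Face hyperparameter into the GGUF rope frequency base. Below is a hedged, stand-alone sketch (not part of this commit) of the same mapping done directly with the gguf-py writer; the hparams values are hypothetical placeholders rather than a real StarCoder2 config.

# Hedged sketch: mirror the set_gguf_parameters logic above with gguf-py.
# The hparams values are hypothetical placeholders, not the published config.
import gguf

hparams = {
    "num_key_value_heads": 4,
    "rope_theta": 100000.0,   # forwarded to the GGUF rope freq base
    "norm_epsilon": 1e-5,     # picked up via the new "norm_epsilon" fallback
}

writer = gguf.GGUFWriter("starcoder2-metadata-only.gguf", arch="starcoder2")

if (n_head_kv := hparams.get("num_key_value_heads")) is not None:
    writer.add_head_count_kv(n_head_kv)
if (rope_theta := hparams.get("rope_theta")) is not None:
    writer.add_rope_freq_base(rope_theta)
if (f_norm_eps := hparams.get("norm_epsilon")) is not None:
    writer.add_layer_norm_eps(f_norm_eps)

# Write a metadata-only file (no tensors) just to inspect the key/value data.
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()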

gguf-py/gguf/constants.py

Lines changed: 21 additions & 0 deletions
@@ -112,6 +112,7 @@ class MODEL_ARCH(IntEnum):
     INTERNLM2 = auto()
     MINICPM = auto()
     GEMMA = auto()
+    STARCODER2 = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -169,6 +170,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.INTERNLM2: "internlm2",
     MODEL_ARCH.MINICPM: "minicpm",
     MODEL_ARCH.GEMMA: "gemma",
+    MODEL_ARCH.STARCODER2: "starcoder2",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -526,6 +528,21 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_UP,
         MODEL_TENSOR.FFN_NORM,
     ],
+    MODEL_ARCH.STARCODER2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }

@@ -554,6 +571,10 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
     ],
+    MODEL_ARCH.STARCODER2: [
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
 }
 
 #
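
The two dictionary entries above register the architecture name and its tensor inventory, while the skip list marks the rope frequency and rotary embedding tensors as not serialized. A hedged sketch (assuming the gguf-py constants shown in this diff) that prints the GGUF tensor names implied by the new MODEL_ARCH.STARCODER2 entry:

# Hedged sketch: enumerate the GGUF tensor names registered for STARCODER2.
from gguf.constants import MODEL_ARCH, MODEL_TENSORS, TENSOR_NAMES

for tensor in MODEL_TENSORS[MODEL_ARCH.STARCODER2]:
    # Per-block names carry a "{bid}" placeholder; block 0 is shown as an example.
    print(TENSOR_NAMES[tensor].format(bid=0))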

gguf-py/gguf/tensor_mapping.py

Lines changed: 2 additions & 0 deletions
@@ -210,6 +210,7 @@ class TensorNameMap:
             "model.layers.layers.{bid}.mlp.up_proj", # plamo
             "model.layers.{bid}.feed_forward.w3",    # internlm2
             "encoder.layers.{bid}.mlp.fc11",         # nomic-bert
+            "model.layers.{bid}.mlp.c_fc",           # starcoder2
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -256,6 +257,7 @@ class TensorNameMap:
             "model.layers.layers.{bid}.mlp.down_proj", # plamo
             "model.layers.{bid}.feed_forward.w2",      # internlm2
             "encoder.layers.{bid}.mlp.fc2",            # nomic-bert
+            "model.layers.{bid}.mlp.c_proj",           # starcoder2
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
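
These two additions let the converter map StarCoder2's mlp.c_fc and mlp.c_proj weights onto the generic ffn_up and ffn_down GGUF tensors. A hedged usage sketch, assuming gguf-py's get_tensor_name_map helper:

# Hedged sketch: resolve a StarCoder2 HF tensor name through the updated map.
import gguf

block_count = 30  # hypothetical; matches the 3B layer count handled in llama.cpp below
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.STARCODER2, block_count)

name = tmap.get_name("model.layers.0.mlp.c_fc.weight", try_suffixes=(".weight", ".bias"))
print(name)  # expected: "blk.0.ffn_up.weight"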

llama.cpp

Lines changed: 199 additions & 0 deletions
@@ -211,6 +211,7 @@ enum llm_arch {
     LLM_ARCH_INTERNLM2,
     LLM_ARCH_MINICPM,
     LLM_ARCH_GEMMA,
+    LLM_ARCH_STARCODER2,
     LLM_ARCH_UNKNOWN,
 };

@@ -238,6 +239,7 @@ static std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_INTERNLM2,  "internlm2" },
     { LLM_ARCH_MINICPM,    "minicpm" },
     { LLM_ARCH_GEMMA,      "gemma" },
+    { LLM_ARCH_STARCODER2, "starcoder2" },
 };
 
 enum llm_kv {
@@ -779,6 +781,24 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_STARCODER2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,    "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,   "output_norm" },
+            { LLM_TENSOR_OUTPUT,        "output" },
+            { LLM_TENSOR_ROPE_FREQS,    "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,     "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,        "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,        "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,        "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,      "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM,      "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN,      "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,        "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -3320,6 +3340,16 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_STARCODER2:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                switch (hparams.n_layer) {
+                    case 30: model.type = e_model::MODEL_3B; break;
+                    case 32: model.type = e_model::MODEL_7B; break;
+                    case 40: model.type = e_model::MODEL_15B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }

@@ -4490,6 +4520,56 @@ static bool llm_load_tensors(
                         layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
                     }
                 } break;
+            case LLM_ARCH_STARCODER2:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
+
+                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
+                        // if output is NULL, init from the input tok embed
+                        if (model.output == NULL) {
+                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                            ml.n_created--; // artificial tensor
+                            ml.size_data += ggml_nbytes(model.output);
+                        }
+
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+                        // optional bias tensors
+                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd});
+                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa});
+                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa});
+                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
+
+                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
+
+                        // optional bias tensors
+                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
+                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff});
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -7559,6 +7639,120 @@ struct llm_build_context {
 
         return gf;
     }
+
+    struct ggml_cgraph * build_starcoder2() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
+        cb(inpL, "inp_embd", -1);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
+        cb(inp_pos, "inp_pos", -1);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        cb(KQ_mask, "KQ_mask", -1);
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                cb(cur, "kqv_out", il);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
+                        NULL,                      NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+            cb(cur, "ffn_out", il);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -7705,6 +7899,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_gemma();
             } break;
+        case LLM_ARCH_STARCODER2:
+            {
+                result = llm.build_starcoder2();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -12084,6 +12282,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_QWEN2:
         case LLM_ARCH_PHI2:
         case LLM_ARCH_GEMMA:
+        case LLM_ARCH_STARCODER2:
             return LLAMA_ROPE_TYPE_NEOX;
 
         // all model arches should be listed explicitly here
