Skip to content

Commit 87daf40

Browse files
author
pockers21
committed
convert: add eagle2 draft arch
1 parent dd8ba93 commit 87daf40

File tree

3 files changed

+40
-0
lines changed

3 files changed

+40
-0
lines changed

convert_hf_to_gguf.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2711,6 +2711,23 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
27112711
return []
27122712
yield from super().modify_tensors(data_torch, name, bid)
27132713

2714+
@ModelBase.register("Eagle2DraftForCausalLM")
class Eagle2DraftModel(TextModel):
    """Converter for EAGLE-2 speculative-decoding draft checkpoints.

    Maps ``Eagle2DraftForCausalLM`` HF checkpoints onto the
    ``EAGLE2_DRAFT`` GGUF architecture (a decoder plus an extra ``fc``
    projection tensor — see the matching entries in gguf constants).
    """
    model_arch = gguf.MODEL_ARCH.EAGLE2_DRAFT

    def set_vocab(self):
        """Load the tokenizer vocabulary into the GGUF writer.

        Prefers a SentencePiece tokenizer; falls back to the GPT-2/BPE
        path when no sentencepiece model file ships with the checkpoint.
        """
        try:
            self._set_vocab_sentencepiece()
        except FileNotFoundError:
            self._set_vocab_gpt2()

    def set_gguf_parameters(self):
        """Write model hyperparameters, including YaRN rope-scaling metadata.

        Only emits rope-scaling fields when the HF config declares a
        ``rope_scaling`` entry with a ``factor`` and type ``yarn``.
        """
        super().set_gguf_parameters()
        # Hoist the nested dict once instead of re-indexing self.hparams
        # on every access.
        rope_scaling = self.hparams.get("rope_scaling")
        if rope_scaling is not None and "factor" in rope_scaling:
            if rope_scaling.get("type") == "yarn":
                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
                self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
                # Guard the lookup: some configs omit
                # "original_max_position_embeddings", which previously
                # raised a KeyError here.
                if "original_max_position_embeddings" in rope_scaling:
                    self.gguf_writer.add_rope_scaling_orig_ctx_len(
                        rope_scaling["original_max_position_embeddings"])
27142731

27152732
@ModelBase.register(
27162733
"Qwen2VLModel",

gguf-py/gguf/constants.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ class MODEL_ARCH(IntEnum):
300300
QWEN2VL = auto()
301301
QWEN3 = auto()
302302
QWEN3MOE = auto()
303+
EAGLE2_DRAFT = auto()
303304
PHI2 = auto()
304305
PHI3 = auto()
305306
PHIMOE = auto()
@@ -360,6 +361,7 @@ class MODEL_TENSOR(IntEnum):
360361
TOKEN_EMBD_NORM = auto()
361362
TOKEN_TYPES = auto()
362363
POS_EMBD = auto()
364+
FC = auto()
363365
OUTPUT = auto()
364366
OUTPUT_NORM = auto()
365367
ROPE_FREQS = auto()
@@ -580,6 +582,7 @@ class MODEL_TENSOR(IntEnum):
580582
MODEL_ARCH.QWEN2VL: "qwen2vl",
581583
MODEL_ARCH.QWEN3: "qwen3",
582584
MODEL_ARCH.QWEN3MOE: "qwen3moe",
585+
MODEL_ARCH.EAGLE2_DRAFT: "eagle2-draft",
583586
MODEL_ARCH.PHI2: "phi2",
584587
MODEL_ARCH.PHI3: "phi3",
585588
MODEL_ARCH.PHIMOE: "phimoe",
@@ -640,6 +643,7 @@ class MODEL_TENSOR(IntEnum):
640643
MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
641644
MODEL_TENSOR.TOKEN_TYPES: "token_types",
642645
MODEL_TENSOR.POS_EMBD: "position_embd",
646+
MODEL_TENSOR.FC: "fc",
643647
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
644648
MODEL_TENSOR.OUTPUT: "output",
645649
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
@@ -1207,6 +1211,21 @@ class MODEL_TENSOR(IntEnum):
12071211
MODEL_TENSOR.FFN_DOWN,
12081212
MODEL_TENSOR.FFN_UP,
12091213
],
1214+
MODEL_ARCH.EAGLE2_DRAFT: [
1215+
MODEL_TENSOR.TOKEN_EMBD,
1216+
MODEL_TENSOR.FC,
1217+
MODEL_TENSOR.OUTPUT,
1218+
MODEL_TENSOR.ATTN_NORM,
1219+
MODEL_TENSOR.ATTN_Q,
1220+
MODEL_TENSOR.ATTN_K,
1221+
MODEL_TENSOR.ATTN_V,
1222+
MODEL_TENSOR.ATTN_OUT,
1223+
MODEL_TENSOR.FFN_NORM,
1224+
MODEL_TENSOR.FFN_GATE,
1225+
MODEL_TENSOR.FFN_DOWN,
1226+
MODEL_TENSOR.FFN_UP,
1227+
1228+
],
12101229
MODEL_ARCH.QWEN2MOE: [
12111230
MODEL_TENSOR.TOKEN_EMBD,
12121231
MODEL_TENSOR.OUTPUT_NORM,

gguf-py/gguf/tensor_mapping.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ class TensorNameMap:
5858
"wpe", # gpt2
5959
),
6060

61+
#eagle2 draft model
62+
MODEL_TENSOR.FC: (
63+
"model.fc",
64+
),
6165
# Output
6266
MODEL_TENSOR.OUTPUT: (
6367
"embed_out", # gptneox

0 commit comments

Comments (0)