Skip to content

Commit 5dad774

Browse files
committed
precompute token maps, very slightly hacky
1 parent 1c7996a commit 5dad774

File tree

3 files changed

+115
-74
lines changed

3 files changed

+115
-74
lines changed

src/llama-vocab.cpp

Lines changed: 2 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -46,49 +46,6 @@ static std::string format(const char * fmt, ...) {
4646
return std::string(buf.data(), size);
4747
}
4848

49-
struct naive_trie {
50-
naive_trie() : has_value(false), value(0) {
51-
}
52-
void insert(const char * key, size_t len, int32_t value = 0) {
53-
if (len == 0) {
54-
this->has_value = true;
55-
this->value = value;
56-
return;
57-
}
58-
char c = key[0];
59-
auto res = children.find(c);
60-
if (res != children.end()) {
61-
res->second.insert(key + 1, len - 1, value);
62-
} else {
63-
auto res = children.insert(std::make_pair(c, naive_trie()));
64-
res.first->second.insert(key + 1, len - 1, value);
65-
}
66-
}
67-
std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) {
68-
if (len == 0 || offset == len) {
69-
return std::make_pair(key, offset);
70-
}
71-
char c = key[offset];
72-
auto res = children.find(c);
73-
if (res != children.end()) {
74-
return res->second.get_longest_prefix(key, len, offset + 1);
75-
} else {
76-
return std::make_pair(key, offset);
77-
}
78-
}
79-
struct naive_trie * traverse(const char c) {
80-
auto res = children.find(c);
81-
if (res != children.end()) {
82-
return &res->second;
83-
} else {
84-
return NULL;
85-
}
86-
}
87-
std::map<char, struct naive_trie> children;
88-
bool has_value;
89-
llama_token value;
90-
};
91-
9249
//
9350
// impl
9451
//
@@ -779,27 +736,6 @@ struct llm_tokenizer_ugm {
779736
prefix_replacements = &vocab.precompiled_charsmap[charsmap_offset];
780737
prefix_replacements_size = vocab.precompiled_charsmap.size() - charsmap_offset;
781738
}
782-
783-
for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
784-
const auto &token_data = vocab.id_to_token[id];
785-
786-
if (llama_is_normal_token(vocab, id)) {
787-
min_score = std::min<float>(min_score, token_data.score);
788-
max_score = std::max<float>(max_score, token_data.score);
789-
}
790-
791-
if (llama_is_normal_token(vocab, id) ||
792-
llama_is_user_defined_token(vocab, id) ||
793-
llama_is_unused_token(vocab, id)) {
794-
token_matcher.insert(token_data.text.data(), token_data.text.size(), id);
795-
}
796-
797-
if (llama_is_user_defined_token(vocab, id)) {
798-
user_defined_token_matcher.insert(token_data.text.data(), token_data.text.size());
799-
}
800-
}
801-
802-
unknown_token_score = min_score - unknown_token_score_penalty;
803739
}
804740

805741
/* This implementation is based on SentencePiece optimized Viterbi algorithm for
@@ -840,7 +776,7 @@ struct llm_tokenizer_ugm {
840776
// traverse the token matcher trie to find a matching token
841777
bool single_codepoint_token_found = false;
842778
const struct best_tokenization & current_best = tokenization_results[input_offset];
843-
struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
779+
const struct naive_trie * node = vocab.token_matcher.traverse(normalized[prefix_offset++]);
844780

845781
while (prefix_offset <= input_len && node != NULL) {
846782
// check if we found valid token in prefix
@@ -1003,7 +939,7 @@ struct llm_tokenizer_ugm {
1003939
}
1004940

1005941
// if input prefix matches some user-defined token return this token as normalization result
1006-
auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
942+
auto user_defined_token_match = vocab.user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
1007943
if (user_defined_token_match.second > 0) {
1008944
return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second };
1009945
}
@@ -1076,22 +1012,14 @@ struct llm_tokenizer_ugm {
10761012
const uint32_t * xcda_array = NULL;
10771013
size_t xcda_array_size = 0;
10781014

1079-
struct naive_trie user_defined_token_matcher;
1080-
10811015
// this structure stores the best tokenization so far at input_offset
10821016
struct best_tokenization {
10831017
llama_token token_id;
10841018
size_t input_offset;
10851019
float score_sum;
10861020
};
10871021

1088-
float min_score = FLT_MAX;
1089-
float max_score = -FLT_MAX;
1090-
1091-
float unknown_token_score_penalty = 10.0;
10921022
float unknown_token_score;
1093-
1094-
struct naive_trie token_matcher;
10951023
};
10961024

10971025
//

src/llama-vocab.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,58 @@
77
#include <unordered_map>
88
#include <map>
99

10+
11+
//
12+
// naive_trie
13+
//
14+
15+
struct naive_trie {
16+
naive_trie() : has_value(false), value(0) {
17+
}
18+
void insert(const char * key, size_t len, int32_t value = 0) {
19+
if (len == 0) {
20+
this->has_value = true;
21+
this->value = value;
22+
return;
23+
}
24+
char c = key[0];
25+
auto res = children.find(c);
26+
if (res != children.end()) {
27+
res->second.insert(key + 1, len - 1, value);
28+
} else {
29+
auto res = children.insert(std::make_pair(c, naive_trie()));
30+
res.first->second.insert(key + 1, len - 1, value);
31+
}
32+
}
33+
std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) const {
34+
if (len == 0 || offset == len) {
35+
return std::make_pair(key, offset);
36+
}
37+
char c = key[offset];
38+
auto res = children.find(c);
39+
if (res != children.end()) {
40+
return res->second.get_longest_prefix(key, len, offset + 1);
41+
} else {
42+
return std::make_pair(key, offset);
43+
}
44+
}
45+
const struct naive_trie * traverse(const char c) const {
46+
auto res = children.find(c);
47+
if (res != children.end()) {
48+
return &res->second;
49+
} else {
50+
return NULL;
51+
}
52+
}
53+
std::map<char, struct naive_trie> children;
54+
bool has_value;
55+
llama_token value;
56+
};
57+
58+
//
59+
// llama_vocab
60+
//
61+
1062
struct llama_vocab {
1163
using id = llama_token;
1264
using token = std::string;
@@ -57,6 +109,9 @@ struct llama_vocab {
57109
bool tokenizer_treat_whitespace_as_suffix = false;
58110

59111
std::vector<char> precompiled_charsmap;
112+
struct naive_trie user_defined_token_matcher;
113+
struct naive_trie token_matcher;
114+
float unknown_token_score = 0.0f;
60115

61116
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
62117
};

src/llama.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5217,6 +5217,36 @@ static void llm_load_hparams(
52175217
hparams.rope_type = llama_rope_type(&model);
52185218
}
52195219

5220+
static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
5221+
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
5222+
return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
5223+
}
5224+
5225+
static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
5226+
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
5227+
return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
5228+
}
5229+
5230+
static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
5231+
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
5232+
return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
5233+
}
5234+
5235+
static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
5236+
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
5237+
return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
5238+
}
5239+
5240+
static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token id) {
5241+
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
5242+
return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
5243+
}
5244+
5245+
static bool llama_is_unused_token(const llama_vocab & vocab, llama_token id) {
5246+
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
5247+
return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED;
5248+
}
5249+
52205250
static void llm_load_vocab(
52215251
llama_model_loader & ml,
52225252
llama_model & model) {
@@ -5598,6 +5628,34 @@ static void llm_load_vocab(
55985628
}
55995629
}
56005630

5631+
// build UGM token-matcher tries and compute unknown_token_score (NOTE: this block does not parse the precompiled charsmap — the old comment was misleading)
5632+
if (vocab.type == LLAMA_VOCAB_TYPE_UGM) {
5633+
float min_score = -FLT_MIN;
5634+
float max_score = FLT_MAX;
5635+
5636+
for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
5637+
const auto &token_data = vocab.id_to_token[id];
5638+
5639+
if (llama_is_normal_token(vocab, id)) {
5640+
min_score = std::min<float>(min_score, token_data.score);
5641+
max_score = std::max<float>(max_score, token_data.score);
5642+
}
5643+
5644+
if (llama_is_normal_token(vocab, id) ||
5645+
llama_is_user_defined_token(vocab, id) ||
5646+
llama_is_unused_token(vocab, id)) {
5647+
vocab.token_matcher.insert(token_data.text.data(), token_data.text.size(), id);
5648+
}
5649+
5650+
if (llama_is_user_defined_token(vocab, id)) {
5651+
vocab.user_defined_token_matcher.insert(token_data.text.data(), token_data.text.size());
5652+
}
5653+
}
5654+
5655+
float unknown_token_score_penalty = 10.0;
5656+
vocab.unknown_token_score = min_score - unknown_token_score_penalty;
5657+
}
5658+
56015659
// Handle add_bos_token and add_eos_token
56025660
{
56035661
bool temp = true;

0 commit comments

Comments
 (0)