
Commit 47068e5

speculative : PoC for speeding-up inference via speculative sampling (#2926)
* speculative : initial example
* speculative : print encoding speed
* speculative : add --draft CLI arg
1 parent 8f429fa commit 47068e5
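
The premise of the PoC: a cheap draft model proposes a short run of tokens (--draft N, default 16), and the expensive target model checks the whole run in a single batched evaluation, keeping the prefix it agrees with. Below is a minimal, library-free sketch of one such draft-and-verify round with greedy acceptance; every name in it (token, sample_draft, verify_target, speculate_round) is illustrative and none of it appears in the commit.

// speculative_round.cpp - illustrative only; "token", "sample_draft" and
// "verify_target" stand in for a real tokenizer and model pair.
#include <cstdio>
#include <functional>
#include <vector>

using token = int;

// One speculation round with greedy acceptance: the cheap draft model proposes
// n_draft tokens autoregressively, the expensive target model scores all of
// them in one batched pass, and we keep the agreeing prefix plus the target's
// own token at the first mismatch - so the output matches what the target
// alone would have produced.
static std::vector<token> speculate_round(
        const std::vector<token> & prefix,
        int n_draft,
        const std::function<token(const std::vector<token> &)> & sample_draft,
        const std::function<std::vector<token>(const std::vector<token> &,
                                               const std::vector<token> &)> & verify_target) {
    // 1. draft: n_draft cheap autoregressive steps
    std::vector<token> drafted;
    std::vector<token> ctx = prefix;
    for (int i = 0; i < n_draft; ++i) {
        const token t = sample_draft(ctx);
        drafted.push_back(t);
        ctx.push_back(t);
    }

    // 2. verify: for each drafted position, the token the target itself would
    //    pick given prefix + drafted[0..i-1] (computable in one batched pass)
    const std::vector<token> target = verify_target(prefix, drafted);

    // 3. accept while the two models agree, substitute the target's token at
    //    the first disagreement and end the round there
    std::vector<token> accepted;
    for (size_t i = 0; i < drafted.size(); ++i) {
        if (target[i] != drafted[i]) {
            accepted.push_back(target[i]);
            break;
        }
        accepted.push_back(drafted[i]);
    }
    return accepted;
}

int main() {
    // toy "models": the target continues an arithmetic sequence, the draft
    // agrees except for one step - the mismatch truncates the round
    const auto sample_draft = [](const std::vector<token> & c) -> token {
        return c.back() + (c.size() == 5 ? 2 : 1);
    };
    const auto verify_target = [](const std::vector<token> & p, const std::vector<token> & d) {
        std::vector<token> out;
        std::vector<token> c = p;
        for (const token t : d) {
            out.push_back(c.back() + 1);
            c.push_back(t);
        }
        return out;
    };

    for (const token t : speculate_round({0, 1, 2}, 4, sample_draft, verify_target)) {
        printf("%d ", t); // prints: 3 4 5
    }
    printf("\n");

    return 0;
}

In the toy run the target confirms the first two drafted tokens and corrects the third, so a single verification pass yields three target-quality tokens; with a well-matched draft model the accepted prefix can be much longer, which is where the speed-up comes from.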

6 files changed, +440 -115 lines changed

common/common.cpp

Lines changed: 140 additions & 0 deletions

@@ -305,6 +305,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_keep = std::stoi(argv[i]);
+        } else if (arg == "--draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_draft = std::stoi(argv[i]);
         } else if (arg == "--chunks") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -317,6 +323,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.model = argv[i];
+        } else if (arg == "-md" || arg == "--model-draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.model_draft = argv[i];
         } else if (arg == "-a" || arg == "--alias") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -638,6 +650,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, " --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
     fprintf(stdout, " --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
     fprintf(stdout, " --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
+    fprintf(stdout, " --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
     fprintf(stdout, " --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
     if (llama_mlock_supported()) {
         fprintf(stdout, " --mlock force system to keep model in RAM rather than swapping or compressing\n");
@@ -669,6 +682,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
     fprintf(stdout, " -m FNAME, --model FNAME\n");
     fprintf(stdout, " model path (default: %s)\n", params.model.c_str());
+    fprintf(stdout, " -md FNAME, --model-draft FNAME\n");
+    fprintf(stdout, " draft model for speculative decoding (default: %s)\n", params.model.c_str());
     fprintf(stdout, " -ld LOGDIR, --logdir LOGDIR\n");
     fprintf(stdout, " path under which to save YAML logs (no logging if unset)\n");
     fprintf(stdout, "\n");
@@ -832,6 +847,130 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_to
     return result;
 }

+//
+// Sampling utils
+//
+
+llama_token llama_sample_token(
+        struct llama_context * ctx,
+        struct llama_context * ctx_guidance,
+        struct llama_grammar * grammar,
+        const struct gpt_params & params,
+        const std::vector<llama_token> & last_tokens,
+        std::vector<llama_token_data> & candidates,
+        int idx) {
+    const int n_ctx   = llama_n_ctx(ctx);
+    const int n_vocab = llama_n_vocab(ctx);
+
+    const float   temp            = params.temp;
+    const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k;
+    const float   top_p           = params.top_p;
+    const float   tfs_z           = params.tfs_z;
+    const float   typical_p       = params.typical_p;
+    const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
+    const float   repeat_penalty  = params.repeat_penalty;
+    const float   alpha_presence  = params.presence_penalty;
+    const float   alpha_frequency = params.frequency_penalty;
+    const int     mirostat        = params.mirostat;
+    const float   mirostat_tau    = params.mirostat_tau;
+    const float   mirostat_eta    = params.mirostat_eta;
+    const bool    penalize_nl     = params.penalize_nl;
+
+    llama_token id = 0;
+
+    float * logits = llama_get_logits(ctx) + idx * n_vocab;
+
+    // Apply params.logit_bias map
+    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        logits[it->first] += it->second;
+    }
+
+    candidates.clear();
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+
+    if (ctx_guidance) {
+        llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
+    }
+
+    // apply penalties
+    if (!last_tokens.empty()) {
+        const float nl_logit = logits[llama_token_nl(ctx)];
+        const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
+
+        llama_sample_repetition_penalty(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, repeat_penalty);
+        llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, alpha_frequency, alpha_presence);
+
+        if (!penalize_nl) {
+            for (size_t idx = 0; idx < cur_p.size; idx++) {
+                if (cur_p.data[idx].id == llama_token_nl(ctx)) {
+                    cur_p.data[idx].logit = nl_logit;
+                    break;
+                }
+            }
+        }
+    }
+
+    if (grammar != NULL) {
+        llama_sample_grammar(ctx, &cur_p, grammar);
+    }
+
+    if (temp <= 0) {
+        // Greedy sampling
+        id = llama_sample_token_greedy(ctx, &cur_p);
+    } else {
+        if (mirostat == 1) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            const int mirostat_m = 100;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+        } else if (mirostat == 2) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+        } else {
+            // Temperature sampling
+            llama_sample_top_k      (ctx, &cur_p, top_k, 1);
+            llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1);
+            llama_sample_typical    (ctx, &cur_p, typical_p, 1);
+            llama_sample_top_p      (ctx, &cur_p, top_p, 1);
+            llama_sample_temperature(ctx, &cur_p, temp);
+
+            {
+                const int n_top = 10;
+                LOG("top %d candidates:\n", n_top);
+
+                for (int i = 0; i < n_top; i++) {
+                    const llama_token id = cur_p.data[i].id;
+                    LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
+                }
+            }
+
+            id = llama_sample_token(ctx, &cur_p);
+
+            LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
+        }
+    }
+    // printf("`%d`", candidates_p.size);
+
+    if (grammar != NULL) {
+        llama_grammar_accept_token(ctx, grammar, id);
+    }
+
+    return id;
+}
+
+//
+// YAML utils
+//
+
 // returns true if successful, false otherwise
 bool create_directory_with_parents(const std::string & path) {
 #ifdef _WIN32
@@ -1070,6 +1209,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta);
     fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
     fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
+    fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
     fprintf(stream, "mtest: %s # default: false\n", params.mem_test ? "true" : "false");
     fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
     fprintf(stream, "n_gpu_layers: %d # default: 0\n", params.n_gpu_layers);

common/common.h

Lines changed: 36 additions & 0 deletions

@@ -32,6 +32,7 @@ struct gpt_params {
     int32_t n_ctx = 512; // context size
     int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep = 0; // number of tokens to keep from initial prompt
+    int32_t n_draft = 16; // number of tokens to draft during speculative decoding
     int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
     int32_t n_gpu_layers = 0; // number of layers to store in VRAM
     int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
@@ -63,6 +64,7 @@ struct gpt_params {
     float cfg_scale = 1.f; // How strong is guidance

     std::string model = "models/7B/ggml-model-f16.gguf"; // model path
+    std::string model_draft = ""; // draft model for speculative decoding
     std::string model_alias = "unknown"; // model alias
     std::string prompt = "";
     std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
@@ -156,6 +158,40 @@ std::string llama_detokenize_bpe(
         llama_context * ctx,
         const std::vector<llama_token> & tokens);

+//
+// Sampling utils
+//
+
+// this is a common sampling function used across the examples for convenience
+// it can serve as a starting point for implementing your own sampling function
+//
+// required:
+//  - ctx:    context to use for sampling
+//  - params: sampling parameters
+//
+// optional:
+//  - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
+//  - grammar:      grammar to use for sampling, ignore if NULL
+//  - last_tokens:  needed for repetition penalty, ignore if empty
+//  - idx:          sample from llama_get_logits(ctx) + idx * n_vocab
+//
+// returns:
+//  - token:      sampled token
+//  - candidates: vector of candidate tokens
+//
+llama_token llama_sample_token(
+        struct llama_context * ctx,
+        struct llama_context * ctx_guidance,
+        struct llama_grammar * grammar,
+        const struct gpt_params & params,
+        const std::vector<llama_token> & last_tokens,
+        std::vector<llama_token_data> & candidates,
+        int idx = 0);
+
+//
+// YAML utils
+//
+
 bool create_directory_with_parents(const std::string & path);
 void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data);
 void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data);
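
Per the comment block above, only ctx and params are required; ctx_guidance and grammar may be NULL, last_tokens may be empty, and idx selects which row of llama_get_logits() to sample from when several tokens were evaluated in one batch, which is exactly what speculative verification needs. A small usage sketch follows; the sample_batch wrapper is hypothetical and assumes the context was created with logits_all enabled.

#include "common.h"

#include <vector>

// sample one token per logits row of the most recently evaluated batch;
// requires a context created with logits_all so that llama_get_logits(ctx)
// holds one row of n_vocab logits per evaluated token
static std::vector<llama_token> sample_batch(
        llama_context * ctx, const gpt_params & params, int n_batch_tokens) {
    std::vector<llama_token>      last_tokens; // empty: no repetition penalty
    std::vector<llama_token_data> candidates;  // scratch buffer reused by the helper

    std::vector<llama_token> result;
    for (int i = 0; i < n_batch_tokens; ++i) {
        // no guidance context, no grammar; idx = i picks the i-th logits row
        result.push_back(llama_sample_token(ctx, NULL, NULL, params, last_tokens, candidates, i));
    }
    return result;
}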

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -23,6 +23,7 @@ else()
     add_subdirectory(train-text-from-scratch)
     add_subdirectory(convert-llama2c-to-ggml)
     add_subdirectory(simple)
+    add_subdirectory(speculative)
     add_subdirectory(embd-input)
     add_subdirectory(llama-bench)
     add_subdirectory(beam-search)
