@@ -305,6 +305,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_keep = std::stoi(argv[i]);
+        } else if (arg == "--draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_draft = std::stoi(argv[i]);
         } else if (arg == "--chunks") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -317,6 +323,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.model = argv[i];
+        } else if (arg == "-md" || arg == "--model-draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.model_draft = argv[i];
         } else if (arg == "-a" || arg == "--alias") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -638,6 +650,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --hellaswag           compute HellaSwag score over random tasks from datafile supplied with -f\n");
     fprintf(stdout, "  --hellaswag-tasks N   number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
     fprintf(stdout, "  --keep N              number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
+    fprintf(stdout, "  --draft N             number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
     fprintf(stdout, "  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
     if (llama_mlock_supported()) {
         fprintf(stdout, "  --mlock               force system to keep model in RAM rather than swapping or compressing\n");
@@ -669,6 +682,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
     fprintf(stdout, "  -m FNAME, --model FNAME\n");
     fprintf(stdout, "                        model path (default: %s)\n", params.model.c_str());
+    fprintf(stdout, "  -md FNAME, --model-draft FNAME\n");
+    fprintf(stdout, "                        draft model for speculative decoding (default: %s)\n", params.model_draft.c_str());
     fprintf(stdout, "  -ld LOGDIR, --logdir LOGDIR\n");
     fprintf(stdout, "                        path under which to save YAML logs (no logging if unset)\n");
     fprintf(stdout, "\n");
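
(Aside, not part of the diff: taken together, --draft and -md/--model-draft give an example program everything it needs to stand up a second, smaller model next to the target one. A minimal sketch of a consumer, assuming the std::tuple-returning llama_init_from_gpt_params() helper from this same file; all other names are illustrative:

    #include "common.h"
    #include "llama.h"

    #include <tuple>

    int main(int argc, char ** argv) {
        gpt_params params;
        if (!gpt_params_parse(argc, argv, params)) {
            return 1;
        }

        llama_model * model_tgt = NULL;  llama_context * ctx_tgt = NULL;
        llama_model * model_dft = NULL;  llama_context * ctx_dft = NULL;

        // load the target (big) model
        std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);

        // load the draft (small) model by reusing the same params with the
        // draft path swapped in
        params.model = params.model_draft;
        std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);

        // ... speculative decoding loop ...

        llama_free(ctx_tgt);  llama_free_model(model_tgt);
        llama_free(ctx_dft);  llama_free_model(model_dft);

        return 0;
    }

)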
@@ -832,6 +847,130 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_to
     return result;
 }
 
+//
+// Sampling utils
+//
+
+llama_token llama_sample_token(
+        struct llama_context * ctx,
+        struct llama_context * ctx_guidance,
+        struct llama_grammar * grammar,
+        const struct gpt_params & params,
+        const std::vector<llama_token> & last_tokens,
+        std::vector<llama_token_data> & candidates,
+        int idx) {
+    const int n_ctx   = llama_n_ctx(ctx);
+    const int n_vocab = llama_n_vocab(ctx);
+
+    const float   temp            = params.temp;
+    const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k;
+    const float   top_p           = params.top_p;
+    const float   tfs_z           = params.tfs_z;
+    const float   typical_p       = params.typical_p;
+    const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
+    const float   repeat_penalty  = params.repeat_penalty;
+    const float   alpha_presence  = params.presence_penalty;
+    const float   alpha_frequency = params.frequency_penalty;
+    const int     mirostat        = params.mirostat;
+    const float   mirostat_tau    = params.mirostat_tau;
+    const float   mirostat_eta    = params.mirostat_eta;
+    const bool    penalize_nl     = params.penalize_nl;
+
+    llama_token id = 0;
+
+    float * logits = llama_get_logits(ctx) + idx * n_vocab;
+
+    // Apply params.logit_bias map
+    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        logits[it->first] += it->second;
+    }
+
+    candidates.clear();
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+
+    if (ctx_guidance) {
+        llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
+    }
+
+    // apply penalties
+    if (!last_tokens.empty()) {
+        const float nl_logit = logits[llama_token_nl(ctx)];
+        const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
+
+        llama_sample_repetition_penalty(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, repeat_penalty);
+        llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
+                last_tokens.data() + last_tokens.size() - last_n_repeat,
+                last_n_repeat, alpha_frequency, alpha_presence);
+
+        if (!penalize_nl) {
+            for (size_t idx = 0; idx < cur_p.size; idx++) {
+                if (cur_p.data[idx].id == llama_token_nl(ctx)) {
+                    cur_p.data[idx].logit = nl_logit;
+                    break;
+                }
+            }
+        }
+    }
+
+    if (grammar != NULL) {
+        llama_sample_grammar(ctx, &cur_p, grammar);
+    }
+
+    if (temp <= 0) {
+        // Greedy sampling
+        id = llama_sample_token_greedy(ctx, &cur_p);
+    } else {
+        if (mirostat == 1) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            const int mirostat_m = 100;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+        } else if (mirostat == 2) {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            llama_sample_temperature(ctx, &cur_p, temp);
+            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+        } else {
+            // Temperature sampling
+            llama_sample_top_k      (ctx, &cur_p, top_k, 1);
+            llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1);
+            llama_sample_typical    (ctx, &cur_p, typical_p, 1);
+            llama_sample_top_p      (ctx, &cur_p, top_p, 1);
+            llama_sample_temperature(ctx, &cur_p, temp);
+
+            {
+                const int n_top = 10;
+                LOG("top %d candidates:\n", n_top);
+
+                for (int i = 0; i < n_top; i++) {
+                    const llama_token id = cur_p.data[i].id;
+                    LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
+                }
+            }
+
+            id = llama_sample_token(ctx, &cur_p);
+
+            LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
+        }
+    }
+    // printf("`%d`", candidates_p.size);
+
+    if (grammar != NULL) {
+        llama_grammar_accept_token(ctx, grammar, id);
+    }
+
+    return id;
+}
+
+//
+// YAML utils
+//
+
 // returns true if successful, false otherwise
 bool create_directory_with_parents(const std::string & path) {
 #ifdef _WIN32
@@ -1070,6 +1209,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta);
     fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
     fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
+    fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
     fprintf(stream, "mtest: %s # default: false\n", params.mem_test ? "true" : "false");
     fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
     fprintf(stream, "n_gpu_layers: %d # default: 0\n", params.n_gpu_layers);
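
(Aside, not part of the diff: llama_sample_token() above is the piece a drafting loop would call repeatedly. A minimal sketch of such a caller, assuming the era's llama_eval(ctx, tokens, n_tokens, n_past, n_threads) API from llama.h; the function name and the caller-tracked n_past are illustrative:

    #include "common.h"
    #include "llama.h"

    #include <vector>

    // Draft up to params.n_draft tokens from the small model's context.
    static std::vector<llama_token> draft_tokens(
            llama_context * ctx_dft, const gpt_params & params,
            std::vector<llama_token> & last_tokens, int n_past) {
        std::vector<llama_token_data> candidates;  // scratch buffer reused by the sampler
        candidates.reserve(llama_n_vocab(ctx_dft));

        std::vector<llama_token> drafted;
        for (int i = 0; i < params.n_draft; ++i) {
            // idx = 0: sample from the logits of the last evaluated token
            const llama_token id = llama_sample_token(
                    ctx_dft, /*ctx_guidance=*/NULL, /*grammar=*/NULL,
                    params, last_tokens, candidates, /*idx=*/0);

            drafted.push_back(id);
            last_tokens.push_back(id);

            // evaluate the drafted token so the next iteration can extend it
            llama_eval(ctx_dft, &id, 1, n_past + i, params.n_threads);
        }

        return drafted;
    }

The target model then scores the drafted tokens in one batch and accepts the longest matching prefix. A run with the new flags would look something like ./speculative -m <target model> -md <draft model> --draft 16, with the binary name and paths purely illustrative.)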