
Commit 5c2aad7

speculative : print encoding speed

1 parent c33cd8a

1 file changed (+20 -3)

examples/speculative/speculative.cpp

@@ -47,6 +47,16 @@ int main(int argc, char ** argv) {
     params.model = params.model_draft;
     std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
 
+    {
+        LOG("warming up the models with an empty run\n");
+
+        const std::vector<llama_token> tmp = { llama_token_bos(ctx_tgt), };
+        llama_eval(ctx_tgt, tmp.data(), tmp.size(), 0, params.n_threads);
+        llama_eval(ctx_dft, tmp.data(), tmp.size(), 0, params.n_threads);
+        llama_reset_timings(ctx_tgt);
+        llama_reset_timings(ctx_dft);
+    }
+
     // tokenize the prompt
     std::vector<llama_token> inp;
     inp = ::llama_tokenize(ctx_tgt, params.prompt, true);
@@ -67,11 +77,17 @@ int main(int argc, char ** argv) {
 
     fflush(stderr);
 
+    const int n_input = inp.size();
+
+    const auto t_enc_start = ggml_time_us();
+
     // eval the prompt with both models
     llama_eval(ctx_tgt, inp.data(), int(inp.size() - 1), 0, params.n_threads);
     llama_eval(ctx_tgt, &inp.back(), 1, inp.size() - 1, params.n_threads);
     llama_eval(ctx_dft, inp.data(), int(inp.size()), 0, params.n_threads);
 
+    const auto t_enc_end = ggml_time_us();
+
     // the 2 models should have the same vocab
     const int n_ctx   = llama_n_ctx(ctx_tgt);
     const int n_vocab = llama_n_vocab(ctx_tgt);
@@ -103,7 +119,7 @@ int main(int argc, char ** argv) {
     // used to determine end of generation
     bool has_eos = false;
 
-    const auto t_gen_start = ggml_time_us();
+    const auto t_dec_start = ggml_time_us();
 
     while (true) {
         LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted));
@@ -193,11 +209,12 @@ int main(int argc, char ** argv) {
             drafted.erase(drafted.begin());
         }
 
-    auto t_gen_end = ggml_time_us();
+    auto t_dec_end = ggml_time_us();
 
     LOG_TEE("\n\n");
 
-    LOG_TEE("generated %d tokens in %.3f seconds, speed: %.3f t/s\n", n_predict, (t_gen_end - t_gen_start) / 1e6f, n_predict / ((t_gen_end - t_gen_start) / 1e6f));
+    LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
+    LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
 
     // TODO: make sure these numbers are computed correctly
     LOG_TEE("\n");
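The warm-up block in the first hunk runs a single BOS token through both the target and draft models and then resets their timing counters, so one-time setup costs (first-run allocations, cache population) are excluded from the encode/decode speeds printed at the end. A minimal standalone sketch of the same warm-up-then-measure pattern, not part of the commit: std::chrono stands in for ggml_time_us, and expensive_eval is a hypothetical stand-in for llama_eval.

    #include <chrono>
    #include <cstdio>

    // hypothetical stand-in for a call with one-time setup cost, e.g. llama_eval
    static void expensive_eval() {
        static bool first = true;
        if (first) { /* allocations, cache population, ... */ first = false; }
    }

    int main() {
        expensive_eval(); // warm-up run: pay the one-time costs here

        // start timing only after the warm-up (mirrors llama_reset_timings)
        const auto t_start = std::chrono::steady_clock::now();
        expensive_eval(); // measured run
        const auto t_end = std::chrono::steady_clock::now();

        printf("measured: %.3f seconds\n",
               std::chrono::duration<double>(t_end - t_start).count());
        return 0;
    }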

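The printed speeds are plain token-count-over-wall-clock arithmetic: ggml_time_us returns microseconds, so dividing the interval by 1e6f yields seconds. (The encoded line passes both n_input and inp.size(), which hold the same count, since n_input is defined as inp.size(); the trailing TODO flags these numbers for review.) A small self-contained illustration of the same computation, with made-up timestamp values:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // hypothetical microsecond timestamps, as ggml_time_us() would return
        const int64_t t_enc_start = 0;
        const int64_t t_enc_end   = 1500000; // 1.5 s later
        const int     n_input     = 120;     // number of prompt tokens encoded

        const float secs = (t_enc_end - t_enc_start) / 1e6f;
        printf("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n",
               n_input, secs, n_input / secs);
        // prints: encoded  120 tokens in    1.500 seconds, speed:   80.000 t/s
        return 0;
    }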