@@ -425,6 +425,29 @@ static void prompt_init(llama_tokens & prompt, const llama_model * model) {
     prompt_add(prompt, model, "<|im_start|>\n", true, true);
 }
 
+static std::vector<llama_token> prepare_guide_tokens(const llama_model * model, const std::string & str)
+{
+    const std::string & delimiter = "<|text_sep|>";
+
+    std::vector<llama_token> result;
+    size_t start = 0;
+    size_t end   = str.find(delimiter);
+
+    while (end != std::string::npos) {
+        std::string current_word = str.substr(start, end - start);
+        auto tmp = common_tokenize(model, current_word, false, true);
+        result.push_back(tmp[0]);
+        start = end + delimiter.length();
+        end   = str.find(delimiter, start);
+    }
+
+    // Add the last part
+    std::string current_word = str.substr(start);
+    auto tmp = common_tokenize(model, current_word, false, true);
+    result.push_back(tmp[0]);
+    return result;
+}
+
 int main(int argc, char ** argv) {
     common_params params;
 
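To see what this helper does, here is a minimal standalone sketch of the same splitting logic (a hypothetical split_on_sep helper; the tokenizer call is replaced by printing each fragment, since common_tokenize needs a loaded model — in the real function only the first token of each fragment is kept as the guide token):

#include <cstdio>
#include <string>
#include <vector>

// Illustration only: the splitting half of prepare_guide_tokens().
// The real code tokenizes each fragment and keeps just tmp[0].
static std::vector<std::string> split_on_sep(const std::string & str) {
    const std::string delimiter = "<|text_sep|>";
    std::vector<std::string> parts;
    size_t start = 0;
    size_t end   = str.find(delimiter);
    while (end != std::string::npos) {
        parts.push_back(str.substr(start, end - start));
        start = end + delimiter.length();
        end   = str.find(delimiter, start);
    }
    parts.push_back(str.substr(start)); // last fragment has no trailing delimiter
    return parts;
}

int main() {
    for (const auto & w : split_on_sep("hello<|text_sep|>world<|text_sep|>again")) {
        printf("'%s'\n", w.c_str()); // prints 'hello', 'world', 'again'
    }
    return 0;
}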
@@ -492,6 +515,7 @@ int main(int argc, char ** argv) {
     const auto t_main_start = ggml_time_us();
 
     std::vector<llama_token> codes;
+    std::vector<llama_token> guide_tokens;
 
     // process prompt and generate voice codes
     {
@@ -506,6 +530,10 @@ int main(int argc, char ** argv) {
         // convert the input text into the necessary format expected by OuteTTS
         {
             std::string prompt_clean = process_text(params.prompt);
+            if (params.vocoder.use_guide_tokens)
+            {
+                guide_tokens = prepare_guide_tokens(model_ttc, prompt_clean);
+            }
 
             LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
 
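For context: at the time of this change, process_text() lowercases the prompt and joins its words with <|text_sep|>, so an input of "Hello world" reaches the call above as hello<|text_sep|>world, and prepare_guide_tokens() returns exactly one guide token per word (the first token of "hello", then of "world").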
@@ -715,6 +743,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
         int n_past   = batch.n_tokens;
         int n_decode = 0;
 
+        bool next_token_uses_guide_token = true;
+
         while (n_decode <= n_predict) {
             // prepare the next batch
             common_batch_clear(batch);
@@ -726,7 +756,18 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
                     continue;
                 }
 
-                const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+                llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+
+                // guide tokens help prevent hallucinations by forcing the TTS to use the correct word
+                if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
+                {
+                    llama_token guide_token = guide_tokens[0];
+                    guide_tokens.erase(guide_tokens.begin());
+                    new_token_id = guide_token; // ensure correct word fragment is used
+                }
+
+                // this is the token id that always precedes a new word
+                next_token_uses_guide_token = (new_token_id == 198);
 
                 common_sampler_accept(smpl[i], new_token_id, true);
 
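Stepping back, this hunk is a small constrained-decoding pattern: sample freely, override the sampled token with the next guide token at each word boundary, then accept the token that was actually emitted so the sampler state stays consistent. Below is a self-contained sketch of that pattern (plain C++ with stand-in sample/accept functions, not the llama.cpp API; the boundary id 198 comes from the patch, the other token ids are made up):

#include <cstdio>
#include <deque>

// Stand-ins for common_sampler_sample() / common_sampler_accept(): the fake
// "model" emits a word-boundary token (198) on every odd step.
static int  sample_token(int step) { return (step % 2 == 1) ? 198 : 555; }
static void accept_token(int tok)  { (void) tok; } // feed emitted token back

int main() {
    std::deque<int> guide = {100, 200, 300}; // first token of each word, in order
    const int word_boundary_id = 198;        // id that precedes each new word
    bool next_uses_guide = true;             // force the very first word too

    for (int step = 0; step < 6; ++step) {
        int tok = sample_token(step);
        if (next_uses_guide && !guide.empty()) {
            tok = guide.front();             // override the free-running sample
            guide.pop_front();
        }
        next_uses_guide = (tok == word_boundary_id);
        accept_token(tok);                   // accept what was actually emitted
        printf("step %d -> token %d\n", step, tok);
    }
    return 0;
}

Running the sketch forces tokens 100, 200, and 300 at steps 0, 2, and 4, one per word boundary, while the odd steps pass the model's output through untouched, which mirrors how the patch pins each word's first fragment without otherwise constraining generation.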