@@ -374,7 +374,7 @@ static std::string replace_numbers_with_words(const std::string & input_text) {
374
374
}
375
375
376
376
// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
377
- static std::string process_text (const std::string & text) {
377
+ static std::string process_text (const std::string & text, const std::string & tts_version = " 0.2 " ) {
378
378
379
379
// For now I skipped text romanization as I am unsure how to handle
380
380
// uroman and MeCab implementations in C++
@@ -404,7 +404,8 @@ static std::string process_text(const std::string & text) {
404
404
if (c == ' ') {
405
405
prompt_clean += "<|text_sep|>";
406
406
*/
407
- processed_text = std::regex_replace (processed_text, std::regex (R"( \s)" ), " <|text_sep|>" );
407
+ std::string separator = (tts_version == " 0.3" ) ? " <|space|>" : " <|text_sep|>" ;
408
+ processed_text = std::regex_replace (processed_text, std::regex (R"( \s)" ), separator);
408
409
409
410
return processed_text;
410
411
}
@@ -428,8 +429,8 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
428
429
prompt_add (prompt, vocab, " <|im_start|>\n " , true , true );
429
430
}
430
431
431
- static std::vector<llama_token> prepare_guide_tokens (const llama_vocab * vocab, const std::string & str) {
432
- const std::string& delimiter = " <| text_sep|>" ;
432
+ static std::vector<llama_token> prepare_guide_tokens (const llama_vocab * vocab, const std::string & str, const std::string & tts_version = " 0.2 " ) {
433
+ const std::string& delimiter = (tts_version == " 0.3 " ? " <|space|> " : " <| text_sep|>" ) ;
433
434
434
435
std::vector<llama_token> result;
435
436
size_t start = 0 ;
@@ -466,31 +467,59 @@ static json speaker_from_file(const std::string & speaker_file) {
466
467
return speaker;
467
468
}
468
469
469
- static std::string audio_text_from_speaker (json speaker) {
470
+ static std::string get_tts_version (llama_model *model, json speaker = json::object()) {
471
+ if (speaker.contains (" version" )) {
472
+ return speaker[" version" ].get <std::string>();
473
+ }
474
+
475
+ // Also could get version from model itself
476
+ const char *chat_template = llama_model_chat_template (model, nullptr );
477
+ if (chat_template && std::string (chat_template) == " outetts-0.3" ) {
478
+ return " 0.3" ;
479
+ }
480
+
481
+ // Use 0.2 as the default version
482
+ return " 0.2" ;
483
+ }
484
+
485
+ static std::string audio_text_from_speaker (json speaker, const std::string & tts_version = " 0.2" ) {
470
486
std::string audio_text = " <|text_start|>" ;
471
- for (const auto &word : speaker[" words" ]) {
472
- audio_text += word[" word" ].get <std::string>() + " <|text_sep|>" ;
487
+
488
+ if (tts_version == " 0.2" || tts_version == " 0.3" ) {
489
+ std::string separator = (tts_version == " 0.3" ) ? " <|space|>" : " <|text_sep|>" ;
490
+ for (const auto &word : speaker[" words" ]) {
491
+ audio_text += word[" word" ].get <std::string>() + separator;
492
+ }
493
+ } else {
494
+ LOG_ERR (" %s: Unsupported speaker version '%s'\n " , __func__, tts_version.c_str ());
473
495
}
474
496
475
497
return audio_text;
476
498
}
477
499
478
- static std::string audio_data_from_speaker (json speaker) {
500
+ static std::string audio_data_from_speaker (json speaker, const std::string & tts_version = " 0.2 " ) {
479
501
std::string audio_data = " <|audio_start|>\n " ;
480
- for (const auto &word : speaker[" words" ]) {
481
- std::string word_text = word[" word" ].get <std::string>();
482
- double duration = word[" duration" ].get <double >();
483
- std::vector<int > codes = word[" codes" ].get <std::vector<int >>();
484
-
485
- // Create the audio output entry
486
- std::ostringstream word_entry;
487
- word_entry << word_text << " <|t_" << std::fixed << std::setprecision (2 )
488
- << duration << " |><|code_start|>" ;
489
- for (const auto &Code : codes) {
490
- word_entry << " <|" << Code << " |>" ;
502
+
503
+ if (tts_version == " 0.2" || tts_version == " 0.3" ) {
504
+ std::string code_start = (tts_version == " 0.3" ) ? " " : " <|code_start|>" ;
505
+ std::string code_end = (tts_version == " 0.3" ) ? " <|space|>" : " <|code_end|>" ;
506
+ for (const auto &word : speaker[" words" ]) {
507
+ std::string word_text = word[" word" ].get <std::string>();
508
+ double duration = word[" duration" ].get <double >();
509
+ std::vector<int > codes = word[" codes" ].get <std::vector<int >>();
510
+
511
+ // Create the audio output entry
512
+ std::ostringstream word_entry;
513
+ word_entry << word_text << " <|t_" << std::fixed << std::setprecision (2 )
514
+ << duration << " |>" + code_start;
515
+ for (const auto &Code : codes) {
516
+ word_entry << " <|" << Code << " |>" ;
517
+ }
518
+ word_entry << code_end << " \n " ;
519
+ audio_data += word_entry.str ();
491
520
}
492
- word_entry << " <|code_end|> \n " ;
493
- audio_data += word_entry. str ( );
521
+ } else {
522
+ LOG_ERR ( " %s: Unsupported speaker version '%s' \n " , __func__, tts_version. c_str () );
494
523
}
495
524
496
525
return audio_data;
@@ -601,6 +630,14 @@ it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><
601
630
looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
602
631
lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)" ;
603
632
633
+ // audio data for 0.3 version
634
+ std::string tts_version = get_tts_version (model_ttc);
635
+ if (tts_version == " 0.3" ) {
636
+ audio_text = std::regex_replace (audio_text, std::regex (R"( <\|text_sep\|>)" ), " <|space|>" );
637
+ audio_data = std::regex_replace (audio_data, std::regex (R"( <\|code_start\|>)" ), " " );
638
+ audio_data = std::regex_replace (audio_data, std::regex (R"( <\|code_end\|>)" ), " <|space|>" );
639
+ }
640
+
604
641
// load speaker if given
605
642
if (!params.vocoder .speaker_file .empty ()) {
606
643
LOG_INF (" %s: loading speaker ..\n " , __func__);
@@ -609,8 +646,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
609
646
LOG_ERR (" %s: Failed to load speaker file '%s'\n " , __func__, params.vocoder .speaker_file .c_str ());
610
647
return 1 ;
611
648
}
612
- audio_text = audio_text_from_speaker (speaker);
613
- audio_data = audio_data_from_speaker (speaker);
649
+ audio_text = audio_text_from_speaker (speaker, tts_version );
650
+ audio_data = audio_data_from_speaker (speaker, tts_version );
614
651
}
615
652
616
653
// process prompt and generate voice codes
@@ -625,9 +662,9 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
625
662
626
663
// convert the input text into the necessary format expected by OuteTTS
627
664
{
628
- std::string prompt_clean = process_text (params.prompt );
665
+ std::string prompt_clean = process_text (params.prompt , tts_version );
629
666
if (params.vocoder .use_guide_tokens ) {
630
- guide_tokens = prepare_guide_tokens (vocab, prompt_clean);
667
+ guide_tokens = prepare_guide_tokens (vocab, prompt_clean, tts_version );
631
668
}
632
669
633
670
LOG_INF (" %s: prompt: '%s'\n " , __func__, prompt_clean.c_str ());
@@ -641,15 +678,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
641
678
prompt_add (prompt_inp, vocab, audio_data, false , true );
642
679
} else {
643
680
// disabled to save time on tokenizing each time
644
- #if 0
681
+ #if 1
645
682
const std::string voice_data = audio_data;
646
683
647
684
auto tmp = common_tokenize (vocab, voice_data, false , true );
648
685
printf (" \n\n " );
649
- for (int i = 0; i < tmp.size(); ++i) {
686
+ for (size_t i = 0 ; i < tmp.size (); ++i) {
650
687
printf (" %d, " , tmp[i]);
651
688
}
652
689
printf (" \n\n " );
690
+ prompt_add (prompt_inp, tmp);
653
691
#else
654
692
prompt_add(prompt_inp, llama_tokens {
655
693
151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,
0 commit comments