19
19
20
20
using json = nlohmann::ordered_json;
21
21
22
// Prompt-format revisions of OuteTTS that the code in this file distinguishes.
// The revision selects which marker tokens are used when prompts are built.
enum outetts_version {
    OUTETTS_V0_2 = 0, // words are separated with <|text_sep|>
    OUTETTS_V0_3 = 1, // words are separated with <|space|>
};
26
+
22
27
//
23
28
// Terminal utils
24
29
//
@@ -374,7 +379,7 @@ static std::string replace_numbers_with_words(const std::string & input_text) {
374
379
}
375
380
376
381
// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
377
- static std::string process_text (const std::string & text) {
382
+ static std::string process_text (const std::string & text, const outetts_version tts_version = OUTETTS_V0_2 ) {
378
383
379
384
// For now I skipped text romanization as I am unsure how to handle
380
385
// uroman and MeCab implementations in C++
@@ -404,7 +409,8 @@ static std::string process_text(const std::string & text) {
404
409
if (c == ' ') {
405
410
prompt_clean += "<|text_sep|>";
406
411
*/
407
- processed_text = std::regex_replace (processed_text, std::regex (R"( \s)" ), " <|text_sep|>" );
412
+ std::string separator = (tts_version == OUTETTS_V0_3) ? " <|space|>" : " <|text_sep|>" ;
413
+ processed_text = std::regex_replace (processed_text, std::regex (R"( \s)" ), separator);
408
414
409
415
return processed_text;
410
416
}
@@ -428,8 +434,8 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
428
434
prompt_add (prompt, vocab, " <|im_start|>\n " , true , true );
429
435
}
430
436
431
- static std::vector<llama_token> prepare_guide_tokens (const llama_vocab * vocab, const std::string & str) {
432
- const std::string& delimiter = " <|text_sep|>" ;
437
+ static std::vector<llama_token> prepare_guide_tokens (const llama_vocab * vocab, const std::string & str, const outetts_version tts_version = OUTETTS_V0_2 ) {
438
+ const std::string& delimiter = (tts_version == OUTETTS_V0_3 ? " <|space|> " : " <| text_sep|>" ) ;
433
439
434
440
std::vector<llama_token> result;
435
441
size_t start = 0 ;
@@ -466,31 +472,62 @@ static json speaker_from_file(const std::string & speaker_file) {
466
472
return speaker;
467
473
}
468
474
469
- static std::string audio_text_from_speaker (json speaker) {
475
+ static outetts_version get_tts_version (llama_model *model, json speaker = json::object()) {
476
+ if (speaker.contains (" version" )) {
477
+ std::string version = speaker[" version" ].get <std::string>();
478
+ if (version == " 0.2" ) {
479
+ return OUTETTS_V0_2;
480
+ } else if (version == " 0.3" ) {
481
+ return OUTETTS_V0_3;
482
+ } else {
483
+ LOG_ERR (" %s: Unsupported speaker version '%s'\n " , __func__, version.c_str ());
484
+ }
485
+ }
486
+
487
+ // Also could get version from model itself
488
+ const char *chat_template = llama_model_chat_template (model, nullptr );
489
+ if (chat_template && std::string (chat_template) == " outetts-0.3" ) {
490
+ return OUTETTS_V0_3;
491
+ }
492
+
493
+ // Use 0.2 as the default version
494
+ return OUTETTS_V0_2;
495
+ }
496
+
497
+ static std::string audio_text_from_speaker (json speaker, const outetts_version tts_version = OUTETTS_V0_2) {
470
498
std::string audio_text = " <|text_start|>" ;
471
- for (const auto &word : speaker[" words" ]) {
472
- audio_text += word[" word" ].get <std::string>() + " <|text_sep|>" ;
499
+
500
+ if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) {
501
+ std::string separator = (tts_version == OUTETTS_V0_3) ? " <|space|>" : " <|text_sep|>" ;
502
+ for (const auto &word : speaker[" words" ]) {
503
+ audio_text += word[" word" ].get <std::string>() + separator;
504
+ }
473
505
}
474
506
475
507
return audio_text;
476
508
}
477
509
478
- static std::string audio_data_from_speaker (json speaker) {
510
+ static std::string audio_data_from_speaker (json speaker, const outetts_version tts_version = OUTETTS_V0_2 ) {
479
511
std::string audio_data = " <|audio_start|>\n " ;
480
- for (const auto &word : speaker[" words" ]) {
481
- std::string word_text = word[" word" ].get <std::string>();
482
- double duration = word[" duration" ].get <double >();
483
- std::vector<int > codes = word[" codes" ].get <std::vector<int >>();
484
-
485
- // Create the audio output entry
486
- std::ostringstream word_entry;
487
- word_entry << word_text << " <|t_" << std::fixed << std::setprecision (2 )
488
- << duration << " |><|code_start|>" ;
489
- for (const auto &Code : codes) {
490
- word_entry << " <|" << Code << " |>" ;
512
+
513
+ if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) {
514
+ std::string code_start = (tts_version == OUTETTS_V0_3) ? " " : " <|code_start|>" ;
515
+ std::string code_end = (tts_version == OUTETTS_V0_3) ? " <|space|>" : " <|code_end|>" ;
516
+ for (const auto &word : speaker[" words" ]) {
517
+ std::string word_text = word[" word" ].get <std::string>();
518
+ double duration = word[" duration" ].get <double >();
519
+ std::vector<int > codes = word[" codes" ].get <std::vector<int >>();
520
+
521
+ // Create the audio output entry
522
+ std::ostringstream word_entry;
523
+ word_entry << word_text << " <|t_" << std::fixed << std::setprecision (2 )
524
+ << duration << " |>" + code_start;
525
+ for (const auto &Code : codes) {
526
+ word_entry << " <|" << Code << " |>" ;
527
+ }
528
+ word_entry << code_end << " \n " ;
529
+ audio_data += word_entry.str ();
491
530
}
492
- word_entry << " <|code_end|>\n " ;
493
- audio_data += word_entry.str ();
494
531
}
495
532
496
533
return audio_data;
@@ -601,6 +638,14 @@ it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><
601
638
looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
602
639
lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)" ;
603
640
641
+ // audio data for 0.3 version
642
+ outetts_version tts_version = get_tts_version (model_ttc);
643
+ if (tts_version == OUTETTS_V0_3) {
644
+ audio_text = std::regex_replace (audio_text, std::regex (R"( <\|text_sep\|>)" ), " <|space|>" );
645
+ audio_data = std::regex_replace (audio_data, std::regex (R"( <\|code_start\|>)" ), " " );
646
+ audio_data = std::regex_replace (audio_data, std::regex (R"( <\|code_end\|>)" ), " <|space|>" );
647
+ }
648
+
604
649
// load speaker if given
605
650
if (!params.vocoder .speaker_file .empty ()) {
606
651
LOG_INF (" %s: loading speaker ..\n " , __func__);
@@ -609,8 +654,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
609
654
LOG_ERR (" %s: Failed to load speaker file '%s'\n " , __func__, params.vocoder .speaker_file .c_str ());
610
655
return 1 ;
611
656
}
612
- audio_text = audio_text_from_speaker (speaker);
613
- audio_data = audio_data_from_speaker (speaker);
657
+ audio_text = audio_text_from_speaker (speaker, tts_version );
658
+ audio_data = audio_data_from_speaker (speaker, tts_version );
614
659
}
615
660
616
661
// process prompt and generate voice codes
@@ -625,9 +670,9 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
625
670
626
671
// convert the input text into the necessary format expected by OuteTTS
627
672
{
628
- std::string prompt_clean = process_text (params.prompt );
673
+ std::string prompt_clean = process_text (params.prompt , tts_version );
629
674
if (params.vocoder .use_guide_tokens ) {
630
- guide_tokens = prepare_guide_tokens (vocab, prompt_clean);
675
+ guide_tokens = prepare_guide_tokens (vocab, prompt_clean, tts_version );
631
676
}
632
677
633
678
LOG_INF (" %s: prompt: '%s'\n " , __func__, prompt_clean.c_str ());
@@ -641,15 +686,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
641
686
prompt_add (prompt_inp, vocab, audio_data, false , true );
642
687
} else {
643
688
// disabled to save time on tokenizing each time
644
- #if 0
689
+ #if 1
645
690
const std::string voice_data = audio_data;
646
691
647
692
auto tmp = common_tokenize (vocab, voice_data, false , true );
648
693
printf (" \n\n " );
649
- for (int i = 0; i < tmp.size(); ++i) {
694
+ for (size_t i = 0 ; i < tmp.size (); ++i) {
650
695
printf (" %d, " , tmp[i]);
651
696
}
652
697
printf (" \n\n " );
698
+ prompt_add (prompt_inp, tmp);
653
699
#else
654
700
prompt_add(prompt_inp, llama_tokens {
655
701
151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,
0 commit comments