Skip to content

Commit 986ade7

Browse files
committed
tts: handle outetts-0.3
1 parent 0b9db0a commit 986ade7

File tree

1 file changed

+73
-27
lines changed

1 file changed

+73
-27
lines changed

examples/tts/tts.cpp

Lines changed: 73 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919

2020
using json = nlohmann::ordered_json;
2121

22+
enum outetts_version {
23+
OUTETTS_V0_2,
24+
OUTETTS_V0_3,
25+
};
26+
2227
//
2328
// Terminal utils
2429
//
@@ -374,7 +379,7 @@ static std::string replace_numbers_with_words(const std::string & input_text) {
374379
}
375380

376381
// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
377-
static std::string process_text(const std::string & text) {
382+
static std::string process_text(const std::string & text, const outetts_version tts_version = OUTETTS_V0_2) {
378383

379384
// For now I skipped text romanization as I am unsure how to handle
380385
// uroman and MeCab implementations in C++
@@ -404,7 +409,8 @@ static std::string process_text(const std::string & text) {
404409
if (c == ' ') {
405410
prompt_clean += "<|text_sep|>";
406411
*/
407-
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>");
412+
std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>";
413+
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), separator);
408414

409415
return processed_text;
410416
}
@@ -428,8 +434,8 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
428434
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
429435
}
430436

431-
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
432-
const std::string& delimiter = "<|text_sep|>";
437+
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const outetts_version tts_version = OUTETTS_V0_2) {
438+
const std::string& delimiter = (tts_version == OUTETTS_V0_3 ? "<|space|>" : "<|text_sep|>");
433439

434440
std::vector<llama_token> result;
435441
size_t start = 0;
@@ -466,31 +472,62 @@ static json speaker_from_file(const std::string & speaker_file) {
466472
return speaker;
467473
}
468474

469-
static std::string audio_text_from_speaker(json speaker) {
475+
static outetts_version get_tts_version(llama_model *model, json speaker = json::object()) {
476+
if (speaker.contains("version")) {
477+
std::string version = speaker["version"].get<std::string>();
478+
if (version == "0.2") {
479+
return OUTETTS_V0_2;
480+
} else if (version == "0.3") {
481+
return OUTETTS_V0_3;
482+
} else {
483+
LOG_ERR("%s: Unsupported speaker version '%s'\n", __func__, version.c_str());
484+
}
485+
}
486+
487+
// Also could get version from model itself
488+
const char *chat_template = llama_model_chat_template(model, nullptr);
489+
if (chat_template && std::string(chat_template) == "outetts-0.3") {
490+
return OUTETTS_V0_3;
491+
}
492+
493+
// Use 0.2 as the default version
494+
return OUTETTS_V0_2;
495+
}
496+
497+
static std::string audio_text_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) {
470498
std::string audio_text = "<|text_start|>";
471-
for (const auto &word : speaker["words"]) {
472-
audio_text += word["word"].get<std::string>() + "<|text_sep|>";
499+
500+
if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) {
501+
std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>";
502+
for (const auto &word : speaker["words"]) {
503+
audio_text += word["word"].get<std::string>() + separator;
504+
}
473505
}
474506

475507
return audio_text;
476508
}
477509

478-
static std::string audio_data_from_speaker(json speaker) {
510+
static std::string audio_data_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) {
479511
std::string audio_data = "<|audio_start|>\n";
480-
for (const auto &word : speaker["words"]) {
481-
std::string word_text = word["word"].get<std::string>();
482-
double duration = word["duration"].get<double>();
483-
std::vector<int> codes = word["codes"].get<std::vector<int>>();
484-
485-
// Create the audio output entry
486-
std::ostringstream word_entry;
487-
word_entry << word_text << "<|t_" << std::fixed << std::setprecision(2)
488-
<< duration << "|><|code_start|>";
489-
for (const auto &Code : codes) {
490-
word_entry << "<|" << Code << "|>";
512+
513+
if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) {
514+
std::string code_start = (tts_version == OUTETTS_V0_3) ? "" : "<|code_start|>";
515+
std::string code_end = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|code_end|>";
516+
for (const auto &word : speaker["words"]) {
517+
std::string word_text = word["word"].get<std::string>();
518+
double duration = word["duration"].get<double>();
519+
std::vector<int> codes = word["codes"].get<std::vector<int>>();
520+
521+
// Create the audio output entry
522+
std::ostringstream word_entry;
523+
word_entry << word_text << "<|t_" << std::fixed << std::setprecision(2)
524+
<< duration << "|>" + code_start;
525+
for (const auto &Code : codes) {
526+
word_entry << "<|" << Code << "|>";
527+
}
528+
word_entry << code_end << "\n";
529+
audio_data += word_entry.str();
491530
}
492-
word_entry << "<|code_end|>\n";
493-
audio_data += word_entry.str();
494531
}
495532

496533
return audio_data;
@@ -601,6 +638,14 @@ it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><
601638
looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
602639
lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";
603640

641+
// audio data for 0.3 version
642+
outetts_version tts_version = get_tts_version(model_ttc);
643+
if (tts_version == OUTETTS_V0_3) {
644+
audio_text = std::regex_replace(audio_text, std::regex(R"(<\|text_sep\|>)"), "<|space|>");
645+
audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_start\|>)"), "");
646+
audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_end\|>)"), "<|space|>");
647+
}
648+
604649
// load speaker if given
605650
if (!params.vocoder.speaker_file.empty()) {
606651
LOG_INF("%s: loading speaker ..\n", __func__);
@@ -609,8 +654,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
609654
LOG_ERR("%s: Failed to load speaker file '%s'\n", __func__, params.vocoder.speaker_file.c_str());
610655
return 1;
611656
}
612-
audio_text = audio_text_from_speaker(speaker);
613-
audio_data = audio_data_from_speaker(speaker);
657+
audio_text = audio_text_from_speaker(speaker, tts_version);
658+
audio_data = audio_data_from_speaker(speaker, tts_version);
614659
}
615660

616661
// process prompt and generate voice codes
@@ -625,9 +670,9 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
625670

626671
// convert the input text into the necessary format expected by OuteTTS
627672
{
628-
std::string prompt_clean = process_text(params.prompt);
673+
std::string prompt_clean = process_text(params.prompt, tts_version);
629674
if (params.vocoder.use_guide_tokens) {
630-
guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
675+
guide_tokens = prepare_guide_tokens(vocab, prompt_clean, tts_version);
631676
}
632677

633678
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
@@ -641,15 +686,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
641686
prompt_add(prompt_inp, vocab, audio_data, false, true);
642687
} else {
643688
// disabled to save time on tokenizing each time
644-
#if 0
689+
#if 1
645690
const std::string voice_data = audio_data;
646691

647692
auto tmp = common_tokenize(vocab, voice_data, false, true);
648693
printf("\n\n");
649-
for (int i = 0; i < tmp.size(); ++i) {
694+
for (size_t i = 0; i < tmp.size(); ++i) {
650695
printf("%d, ", tmp[i]);
651696
}
652697
printf("\n\n");
698+
prompt_add(prompt_inp, tmp);
653699
#else
654700
prompt_add(prompt_inp, llama_tokens {
655701
151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,

0 commit comments

Comments
 (0)