Skip to content

Commit 888f57e

Browse files
committed
tts: handle outetts-0.3
1 parent 0b9db0a commit 888f57e

File tree

1 file changed

+65
-27
lines changed

1 file changed

+65
-27
lines changed

examples/tts/tts.cpp

Lines changed: 65 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ static std::string replace_numbers_with_words(const std::string & input_text) {
374374
}
375375

376376
// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
377-
static std::string process_text(const std::string & text) {
377+
static std::string process_text(const std::string & text, const std::string & tts_version = "0.2") {
378378

379379
// For now I skipped text romanization as I am unsure how to handle
380380
// uroman and MeCab implementations in C++
@@ -404,7 +404,8 @@ static std::string process_text(const std::string & text) {
404404
if (c == ' ') {
405405
prompt_clean += "<|text_sep|>";
406406
*/
407-
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>");
407+
std::string separator = (tts_version == "0.3") ? "<|space|>" : "<|text_sep|>";
408+
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), separator);
408409

409410
return processed_text;
410411
}
@@ -428,8 +429,8 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
428429
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
429430
}
430431

431-
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
432-
const std::string& delimiter = "<|text_sep|>";
432+
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const std::string & tts_version = "0.2") {
433+
const std::string& delimiter = (tts_version == "0.3" ? "<|space|>" : "<|text_sep|>");
433434

434435
std::vector<llama_token> result;
435436
size_t start = 0;
@@ -466,31 +467,59 @@ static json speaker_from_file(const std::string & speaker_file) {
466467
return speaker;
467468
}
468469

469-
static std::string audio_text_from_speaker(json speaker) {
470+
static std::string get_tts_version(llama_model *model, json speaker = json::object()) {
471+
if (speaker.contains("version")) {
472+
return speaker["version"].get<std::string>();
473+
}
474+
475+
// Also could get version from model itself
476+
const char *chat_template = llama_model_chat_template(model, nullptr);
477+
if (chat_template && std::string(chat_template) == "outetts-0.3") {
478+
return "0.3";
479+
}
480+
481+
// Use 0.2 as the default version
482+
return "0.2";
483+
}
484+
485+
static std::string audio_text_from_speaker(json speaker, const std::string & tts_version = "0.2") {
470486
std::string audio_text = "<|text_start|>";
471-
for (const auto &word : speaker["words"]) {
472-
audio_text += word["word"].get<std::string>() + "<|text_sep|>";
487+
488+
if (tts_version == "0.2" || tts_version == "0.3") {
489+
std::string separator = (tts_version == "0.3") ? "<|space|>" : "<|text_sep|>";
490+
for (const auto &word : speaker["words"]) {
491+
audio_text += word["word"].get<std::string>() + separator;
492+
}
493+
} else {
494+
LOG_ERR("%s: Unsupported speaker version '%s'\n", __func__, tts_version.c_str());
473495
}
474496

475497
return audio_text;
476498
}
477499

478-
static std::string audio_data_from_speaker(json speaker) {
500+
static std::string audio_data_from_speaker(json speaker, const std::string & tts_version = "0.2") {
479501
std::string audio_data = "<|audio_start|>\n";
480-
for (const auto &word : speaker["words"]) {
481-
std::string word_text = word["word"].get<std::string>();
482-
double duration = word["duration"].get<double>();
483-
std::vector<int> codes = word["codes"].get<std::vector<int>>();
484-
485-
// Create the audio output entry
486-
std::ostringstream word_entry;
487-
word_entry << word_text << "<|t_" << std::fixed << std::setprecision(2)
488-
<< duration << "|><|code_start|>";
489-
for (const auto &Code : codes) {
490-
word_entry << "<|" << Code << "|>";
502+
503+
if (tts_version == "0.2" || tts_version == "0.3") {
504+
std::string code_start = (tts_version == "0.3") ? "" : "<|code_start|>";
505+
std::string code_end = (tts_version == "0.3") ? "<|space|>" : "<|code_end|>";
506+
for (const auto &word : speaker["words"]) {
507+
std::string word_text = word["word"].get<std::string>();
508+
double duration = word["duration"].get<double>();
509+
std::vector<int> codes = word["codes"].get<std::vector<int>>();
510+
511+
// Create the audio output entry
512+
std::ostringstream word_entry;
513+
word_entry << word_text << "<|t_" << std::fixed << std::setprecision(2)
514+
<< duration << "|>" + code_start;
515+
for (const auto &Code : codes) {
516+
word_entry << "<|" << Code << "|>";
517+
}
518+
word_entry << code_end << "\n";
519+
audio_data += word_entry.str();
491520
}
492-
word_entry << "<|code_end|>\n";
493-
audio_data += word_entry.str();
521+
} else {
522+
LOG_ERR("%s: Unsupported speaker version '%s'\n", __func__, tts_version.c_str());
494523
}
495524

496525
return audio_data;
@@ -601,6 +630,14 @@ it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><
601630
looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
602631
lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";
603632

633+
// audio data for 0.3 version
634+
std::string tts_version = get_tts_version(model_ttc);
635+
if (tts_version == "0.3") {
636+
audio_text = std::regex_replace(audio_text, std::regex(R"(<\|text_sep\|>)"), "<|space|>");
637+
audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_start\|>)"), "");
638+
audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_end\|>)"), "<|space|>");
639+
}
640+
604641
// load speaker if given
605642
if (!params.vocoder.speaker_file.empty()) {
606643
LOG_INF("%s: loading speaker ..\n", __func__);
@@ -609,8 +646,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
609646
LOG_ERR("%s: Failed to load speaker file '%s'\n", __func__, params.vocoder.speaker_file.c_str());
610647
return 1;
611648
}
612-
audio_text = audio_text_from_speaker(speaker);
613-
audio_data = audio_data_from_speaker(speaker);
649+
audio_text = audio_text_from_speaker(speaker, tts_version);
650+
audio_data = audio_data_from_speaker(speaker, tts_version);
614651
}
615652

616653
// process prompt and generate voice codes
@@ -625,9 +662,9 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
625662

626663
// convert the input text into the necessary format expected by OuteTTS
627664
{
628-
std::string prompt_clean = process_text(params.prompt);
665+
std::string prompt_clean = process_text(params.prompt, tts_version);
629666
if (params.vocoder.use_guide_tokens) {
630-
guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
667+
guide_tokens = prepare_guide_tokens(vocab, prompt_clean, tts_version);
631668
}
632669

633670
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
@@ -641,15 +678,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
641678
prompt_add(prompt_inp, vocab, audio_data, false, true);
642679
} else {
643680
// disabled to save time on tokenizing each time
644-
#if 0
681+
#if 1
645682
const std::string voice_data = audio_data;
646683

647684
auto tmp = common_tokenize(vocab, voice_data, false, true);
648685
printf("\n\n");
649-
for (int i = 0; i < tmp.size(); ++i) {
686+
for (size_t i = 0; i < tmp.size(); ++i) {
650687
printf("%d, ", tmp[i]);
651688
}
652689
printf("\n\n");
690+
prompt_add(prompt_inp, tmp);
653691
#else
654692
prompt_add(prompt_inp, llama_tokens {
655693
151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,

0 commit comments

Comments
 (0)