
Commit 07173e8

add ability to use guide tokens for TTS, ref: ggml-org#11186
1 parent bd38665 commit 07173e8

4 files changed, +53 -3 lines changed

common/arg.cpp

Lines changed: 7 additions & 0 deletions
@@ -2215,6 +2215,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.vocoder.model = value;
         }
     ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--tts-use-guide-tokens"},
+        "Use guide tokens to improve TTS word recall",
+        [](common_params & params) {
+            params.vocoder.use_guide_tokens = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
 
     // model-specific
     add_opt(common_arg(
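The new argument is registered for the TTS and server examples only (via set_examples above). A hedged usage sketch follows; the llama-tts binary name, the -m/-mv/-p options, and the model file names are assumptions for illustration, and only the --tts-use-guide-tokens flag itself comes from this commit:

llama-tts -m OuteTTS-0.2-500M-Q8_0.gguf -mv WavTokenizer-Large-75-Q8_0.gguf \
    -p "Hello world, this is a test." --tts-use-guide-tokens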

common/common.h

Lines changed: 2 additions & 0 deletions
@@ -174,6 +174,8 @@ struct common_params_vocoder {
 
     std::string model = ""; // model path // NOLINT
     std::string model_url = ""; // model url to download // NOLINT
+
+    bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
 };
 
 struct common_params {

examples/tts/tts.cpp

Lines changed: 42 additions & 1 deletion
@@ -425,6 +425,29 @@ static void prompt_init(llama_tokens & prompt, const llama_model * model) {
     prompt_add(prompt, model, "<|im_start|>\n", true, true);
 }
 
+static std::vector<llama_token> prepare_guide_tokens(const llama_model * model, const std::string& str)
+{
+    const std::string& delimiter = "<|text_sep|>";
+
+    std::vector<llama_token> result;
+    size_t start = 0;
+    size_t end = str.find(delimiter);
+
+    while (end != std::string::npos) {
+        std::string current_word = str.substr(start, end - start);
+        auto tmp = common_tokenize(model, current_word, false, true);
+        result.push_back(tmp[0]);
+        start = end + delimiter.length();
+        end = str.find(delimiter, start);
+    }
+
+    // Add the last part
+    std::string current_word = str.substr(start);
+    auto tmp = common_tokenize(model, current_word, false, true);
+    result.push_back(tmp[0]);
+    return result;
+}
+
 int main(int argc, char ** argv) {
     common_params params;
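To make the splitting step concrete, here is a small standalone sketch of what prepare_guide_tokens() does with an OuteTTS-style prompt: it cuts the cleaned text at every "<|text_sep|>" delimiter and keeps one token per word fragment. The llama.cpp tokenizer is stubbed out with fake_first_token(), which is not a real API; it merely stands in for "the first token id returned by common_tokenize()".

#include <cstdio>
#include <string>
#include <vector>

// stand-in for the first token id returned by common_tokenize(model, word, false, true)
static int fake_first_token(const std::string & word) {
    int h = 0;
    for (unsigned char c : word) { h = h * 31 + c; }
    return h & 0x7fff;
}

int main() {
    const std::string prompt    = "hello<|text_sep|>world<|text_sep|>again";
    const std::string delimiter = "<|text_sep|>";

    std::vector<int> guide_tokens;
    size_t start = 0;
    size_t end   = prompt.find(delimiter);

    // one guide token per word fragment, mirroring the loop in the diff above
    while (end != std::string::npos) {
        guide_tokens.push_back(fake_first_token(prompt.substr(start, end - start)));
        start = end + delimiter.length();
        end   = prompt.find(delimiter, start);
    }
    guide_tokens.push_back(fake_first_token(prompt.substr(start))); // last word

    printf("%zu guide tokens for 3 words\n", guide_tokens.size());
    return 0;
}

For the three-word prompt above this prints "3 guide tokens for 3 words"; the real function returns llama_token ids from the model's tokenizer instead of the fake hashes used here.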

@@ -492,6 +515,7 @@ int main(int argc, char ** argv) {
     const auto t_main_start = ggml_time_us();
 
     std::vector<llama_token> codes;
+    std::vector<llama_token> guide_tokens;
 
     // process prompt and generate voice codes
     {
@@ -506,6 +530,10 @@ int main(int argc, char ** argv) {
         // convert the input text into the necessary format expected by OuteTTS
         {
             std::string prompt_clean = process_text(params.prompt);
+            if(params.vocoder.use_guide_tokens)
+            {
+                guide_tokens = prepare_guide_tokens(model_ttc,prompt_clean);
+            }
 
             LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
@@ -715,6 +743,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
         int n_past = batch.n_tokens;
         int n_decode = 0;
 
+        bool next_token_uses_guide_token = true;
+
         while (n_decode <= n_predict) {
             // prepare the next batch
             common_batch_clear(batch);
@@ -726,7 +756,18 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
                 continue;
             }
 
-            const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+            llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+
+            //guide tokens help prevent hallucinations by forcing the TTS to use the correct word
+            if(!guide_tokens.empty() && next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
+            {
+                llama_token guide_token = guide_tokens[0];
+                guide_tokens.erase(guide_tokens.begin());
+                new_token_id = guide_token; //ensure correct word fragment is used
+            }
+
+            //this is the token id that always precedes a new word
+            next_token_uses_guide_token = (new_token_id == 198);
 
             common_sampler_accept(smpl[i], new_token_id, true);
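The decode-loop change is easier to follow outside the diff, so here is a standalone sketch of the forcing logic under simplified assumptions: token ids are plain ints, the sampler output is a fixed array, and is_control()/is_eog() are placeholder predicates rather than the real llama_token_is_control()/llama_token_is_eog() calls. Whenever the previously emitted token is id 198 (the word-boundary token noted in the commit), the next sampled token is replaced by the next queued guide token.

#include <cstdio>
#include <deque>

static bool is_control(int tok) { return tok < 0; }   // placeholder predicate
static bool is_eog(int tok)     { return tok == -2; } // placeholder predicate

int main() {
    std::deque<int> guide_tokens = {5001, 5002, 5003}; // first token id of each word
    const int word_start_id = 198;                     // token that always precedes a new word

    int  sampled[] = {198, 42, 7, 198, 99, 198, 13};   // pretend sampler output
    bool next_token_uses_guide_token = true;

    for (int new_token_id : sampled) {
        // right after a word boundary, override the sampled token with the next guide token
        if (!guide_tokens.empty() && next_token_uses_guide_token &&
            !is_control(new_token_id) && !is_eog(new_token_id)) {
            new_token_id = guide_tokens.front();
            guide_tokens.pop_front();
        }
        next_token_uses_guide_token = (new_token_id == word_start_id);
        printf("emit %d\n", new_token_id);
    }
    return 0;
}

As in the commit, the flag starts out true, so the very first decoded token is also forced to the first guide token; once the queue is exhausted the sampled tokens pass through unchanged.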

src/llama.cpp

Lines changed: 2 additions & 2 deletions
@@ -11122,9 +11122,9 @@ static int llama_decode_impl(
         }
     }
 
-    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
+    GGML_ASSERT_CONTINUE(n_tokens_all <= cparams.n_batch);
 
-    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+    GGML_ASSERT_CONTINUE((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
 
     if (lctx.t_compute_start_us == 0) {
         lctx.t_compute_start_us = ggml_time_us();
