Skip to content

Commit 6d47745

Browse files
authored
Fix bad words logits processor (#1278)
* Skip test if there aren't enough tokens to match the banned sequence
* Add unit test
* Update logits_process.test.js
1 parent 10c09fb commit 6d47745

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

src/generation/logits_process.js

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -548,13 +548,15 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor {
548548
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
549549
const ids = input_ids[i];
550550
for (const bad_word_ids of this.bad_words_ids) {
551+
// There aren't enough tokens to match the banned sequence
552+
if (ids.length < bad_word_ids.length - 1) continue;
553+
551554
// Whether to modify the logits of the last token in the bad word id sequence
552555
let mark = true;
553556

554557
// For each bad word in the list, if the current sequence of input ids ends with this sequence (excluding the last),
555558
// then we set the logits of the last bad word id to -Infinity.
556-
for (let j = 1; j <= bad_word_ids.length - 1 && bad_word_ids.length < ids.length; ++j) {
557-
559+
for (let j = 1; j <= bad_word_ids.length - 1; ++j) {
558560
// NOTE: We use != instead of !== to compare bigint and number
559561
// @ts-ignore
560562
if (bad_word_ids.at(-j - 1) != ids.at(-j)) {

tests/utils/logits_process.test.js

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,31 @@ describe("Logits Processors", () => {
7979
},
8080
MAX_TEST_EXECUTION_TIME,
8181
);
82+
83+
it(
84+
"different lengths",
85+
async () => {
86+
const text_input = "this is a test";
87+
88+
const generated_text_target = "кт México constructed lake user";
89+
const text_target = [{ generated_text: text_input + generated_text_target }];
90+
91+
const output = await pipe(text_input, {
92+
max_new_tokens: 5,
93+
bad_words_ids: [
94+
// default: [445n, 338n, 263n, 1243n, 3931n, 14756n, 7811n, 21645n, 16426n]
95+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3931], // should never trigger (longer than input sequence)
96+
97+
// block #1: [445n, 338n, 263n, 1243n, 3931n, 14756n, 7811n, 21645n, 16426n]
98+
[3931, 14756, 7811],
99+
100+
// result: [445n, 338n, 263n, 1243n, 3931n, 14756n, 13319n, 19437n, 1404n]
101+
],
102+
});
103+
compare(output, text_target);
104+
},
105+
MAX_TEST_EXECUTION_TIME,
106+
);
82107
});
83108

84109
afterAll(async () => {

0 commit comments

Comments
 (0)