diff --git a/bindings/python/Makefile b/bindings/python/Makefile
index f07cff584..7e151f92b 100644
--- a/bindings/python/Makefile
+++ b/bindings/python/Makefile
@@ -14,8 +14,8 @@ style:
 # Check the source code is formatted correctly
 check-style:
 	python stub.py --check
-	ruff check examples py_src/tokenizers tests
-	ruff format --check examples py_src/tokenizers tests
+	ruff check $(check_dirs)
+	ruff format --check $(check_dirs)
 
 TESTS_RESOURCES = $(DATA_DIR)/small.txt $(DATA_DIR)/roberta.json
 
diff --git a/bindings/python/py_src/tokenizers/tools/visualizer.py b/bindings/python/py_src/tokenizers/tools/visualizer.py
index b7abb7013..791de498e 100644
--- a/bindings/python/py_src/tokenizers/tools/visualizer.py
+++ b/bindings/python/py_src/tokenizers/tools/visualizer.py
@@ -241,7 +241,7 @@ def consecutive_chars_to_html(
         # In this case we are looking at a group/single char that is not tokenized.
         # e.g. white space
         css_classes.append("non-token")
-    css = f'''class="{' '.join(css_classes)}"'''
+    css = f'''class="{" ".join(css_classes)}"'''
     data = ""
     for key, val in data_items.items():
         data += f' data-{key}="{val}"'
diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs
index 98b7d6b72..4a408ff1d 100644
--- a/bindings/python/src/decoders.rs
+++ b/bindings/python/src/decoders.rs
@@ -646,11 +646,6 @@ pub struct PyDecodeStream {
     /// The index within the ids corresponding to the prefix so we can drain
     /// correctly
     prefix_index: usize,
-    /// We need to keep 2 prefixes.
-    /// Prefix is the second one that was already emitted to discard the part
-    /// of the text of all the ids
-    /// read is the prefix kept only for starting side effects of the prefix
-    read_index: usize,
 }
 
 #[pymethods]
@@ -663,7 +658,6 @@ impl PyDecodeStream {
             ids: vec![],
             prefix: "".to_string(),
             prefix_index: 0,
-            read_index: 0,
         }
     }
 
@@ -676,7 +670,6 @@ impl PyDecodeStream {
             &mut self.ids,
             &mut self.prefix,
             &mut self.prefix_index,
-            &mut self.read_index,
         ))
         .into()
     }
diff --git a/tokenizers/Makefile b/tokenizers/Makefile
index a407afffc..00e142b17 100644
--- a/tokenizers/Makefile
+++ b/tokenizers/Makefile
@@ -4,7 +4,7 @@ TESTS_DIR = tests
 
 dir_guard=@mkdir -p $(@D)
 
-SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt $(DATA_DIR)/albert-base-v1-tokenizer.json
+SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt $(DATA_DIR)/albert-base-v1-tokenizer.json $(DATA_DIR)/llama-3-tokenizer.json
 BENCHMARK_RESOURCES = $(SHARED_RESOURCES)
 TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt $(DATA_DIR)/roberta.json $(DATA_DIR)/tokenizer-wiki.json $(DATA_DIR)/bert-wiki.json
 
@@ -79,3 +79,7 @@ $(DATA_DIR)/tokenizer-wiki.json :
 $(DATA_DIR)/bert-wiki.json :
 	$(dir_guard)
 	wget https://s3.amazonaws.com/models.huggingface.co/bert/anthony/doc-pipeline/tokenizer.json -O $@
+
+$(DATA_DIR)/llama-3-tokenizer.json :
+	$(dir_guard)
+	wget https://huggingface.co/hf-internal-testing/llama3-tokenizer/resolve/main/tokenizer.json -O $@
diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs
index 6fc8033e3..93b3d9c37 100644
--- a/tokenizers/src/models/bpe/word.rs
+++ b/tokenizers/src/models/bpe/word.rs
@@ -199,9 +199,9 @@ impl Word {
 
             // Make sure we are not processing an expired queue entry
            let target_new_pair = (self.symbols[top.pos].c, right.c);
-            if !merges
+            if merges
                 .get(&target_new_pair)
-                .map_or(false, |(_, new_id)| *new_id == top.new_id)
+                .is_none_or(|(_, new_id)| *new_id != top.new_id)
             {
                 continue;
             }
diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs
index 01394e474..58462331d 100644
--- a/tokenizers/src/processors/template.rs
+++ b/tokenizers/src/processors/template.rs
@@ -441,7 +441,7 @@ impl TemplateProcessingBuilder {
                 let exist = self
                     .special_tokens
                     .as_ref()
-                    .map_or(false, |map| map.0.contains_key(sp));
+                    .is_some_and(|map| map.0.contains_key(sp));
 
                 match exist {
                     false => Some(sp),
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index b67681102..808d120d5 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -1035,11 +1035,6 @@ pub struct DecodeStream<'tok, M, N, PT, PP, D> {
     /// The index within the ids corresponding to the prefix so we can drain
     /// correctly
     prefix_index: usize,
-    /// We need to keep 2 prefixes.
-    /// Prefix is the second one that was already emitted to discard the part
-    /// of the text of all the ids
-    /// read is the prefix kept only for starting side effects of the prefix
-    read_index: usize,
 }
 
 #[derive(thiserror::Error, Debug)]
@@ -1063,7 +1058,6 @@ where
             skip_special_tokens,
             prefix: "".to_string(),
             prefix_index: 0,
-            read_index: 0,
         }
     }
 
@@ -1076,7 +1070,6 @@ where
             &mut self.ids,
             &mut self.prefix,
             &mut self.prefix_index,
-            &mut self.read_index,
         )
     }
 }
@@ -1089,7 +1082,6 @@ pub fn step_decode_stream<M, N, PT, PP, D>(
     ids: &mut Vec<u32>,
     prefix: &mut String,
     prefix_index: &mut usize,
-    read_index: &mut usize,
 ) -> Result<Option<String>>
 where
     M: Model,
@@ -1108,7 +1100,6 @@ where
         let new_prefix_index = ids.len() - *prefix_index;
         *ids = ids.drain(*prefix_index..).collect();
         *prefix = tokenizer.decode(ids, skip_special_tokens)?;
-        *read_index = *prefix_index;
         *prefix_index = new_prefix_index;
         Ok(Some(new_text.to_string()))
     } else {
@@ -1563,112 +1554,3 @@ where
         Ok(())
     }
 }
-
-#[cfg(test)]
-mod test {
-    #[cfg(feature = "http")]
-    #[test]
-    fn test_decoding_with_added_bpe() {
-        use crate::{
-            normalizers,
-            pre_tokenizers::split::{Split, SplitPattern},
-            AddedToken, NormalizerWrapper, PreTokenizerWrapper, SplitDelimiterBehavior, Tokenizer,
-        };
-
-        let mut tokenizer = Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B", None).unwrap();
-        tokenizer.normalizer = Some(NormalizerWrapper::from(normalizers::ByteLevel::new()));
-        tokenizer.pre_tokenizer = Some(PreTokenizerWrapper::Split(
-            Split::new(
-                SplitPattern::Regex(r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+".into()),
-                SplitDelimiterBehavior::Isolated,
-                false,
-            )
-            .unwrap(),
-        ));
-        tokenizer.add_tokens(&[AddedToken::from("嗎", false).normalized(false)]);
-        let encoded = tokenizer
-            .encode("Hey! how is this token: 嗎", false)
-            .unwrap();
-        assert_eq!(
-            encoded.get_ids(),
-            [19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128256]
-        );
-        assert_eq!(
-            encoded.get_tokens(),
-            ["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "嗎"]
-        );
-
-        let decoded = tokenizer.decode(encoded.get_ids(), false);
-        assert_eq!(decoded.unwrap(), "Hey! how is this token: 嗎");
-
-        tokenizer.add_tokens(&[AddedToken::from("д", false).normalized(true)]);
-        let encoded = tokenizer
-            .encode("Hey! how is this token: д", false)
-            .unwrap();
-        assert_eq!(
-            encoded.get_ids(),
-            [19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128257]
-        );
-        assert_eq!(
-            encoded.get_tokens(),
-            ["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "д"]
-        );
-        let decoded = tokenizer.decode(encoded.get_ids(), false);
-        assert_eq!(decoded.unwrap(), "Hey! how is this token: д")
-    }
-
-    #[cfg(feature = "http")]
-    #[test]
-    fn test_decode_stream_step_no_panic() {
-        use std::panic;
-
-        use crate::Tokenizer;
-
-        let tokenizer = Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B", None).unwrap();
-
-        // "A B C D E F G H I J"
-        let mut decode_stream = tokenizer.decode_stream(false);
-        let output_tokens = vec![32, 426, 356, 423, 469, 435, 480, 473, 358, 622];
-        let expected_outputs = vec![
-            Some("A".to_string()),
-            Some(" B".to_string()),
-            Some(" C".to_string()),
-            Some(" D".to_string()),
-            Some(" E".to_string()),
-            Some(" F".to_string()),
-            Some(" G".to_string()),
-            Some(" H".to_string()),
-            Some(" I".to_string()),
-            Some(" J".to_string()),
-        ];
-        for (i, &token) in output_tokens.iter().enumerate() {
-            let maybe_panic =
-                panic::catch_unwind(panic::AssertUnwindSafe(|| decode_stream.step(token)));
-            assert!(maybe_panic.is_ok());
-            let result = maybe_panic.unwrap();
-            assert!(result.is_ok());
-            assert_eq!(result.unwrap(), expected_outputs[i]);
-        }
-
-        // "삥뽕빵" (Korean words composed of 2-3 tokens: [80690, 98], [167, 121, 243], and [102457, 113])
-        let mut decode_stream = tokenizer.decode_stream(false);
-        let output_tokens = vec![80690, 98, 167, 121, 243, 102457, 113];
-        let expected_outputs = vec![
-            None,
-            Some("삥".to_string()),
-            None,
-            None,
-            Some("뽕".to_string()),
-            None,
-            Some("빵".to_string()),
-        ];
-        for (i, &token) in output_tokens.iter().enumerate() {
-            let maybe_panic =
-                panic::catch_unwind(panic::AssertUnwindSafe(|| decode_stream.step(token)));
-            assert!(maybe_panic.is_ok());
-            let result = maybe_panic.unwrap();
-            assert!(result.is_ok());
-            assert_eq!(result.unwrap(), expected_outputs[i]);
-        }
-    }
-}
diff --git a/tokenizers/tests/stream.rs b/tokenizers/tests/stream.rs
new file mode 100644
index 000000000..c4cfee3dd
--- /dev/null
+++ b/tokenizers/tests/stream.rs
@@ -0,0 +1,78 @@
+use tokenizers::{
+    normalizers,
+    pre_tokenizers::split::{Split, SplitPattern},
+    AddedToken, NormalizerWrapper, PreTokenizerWrapper, SplitDelimiterBehavior, Tokenizer,
+};
+
+#[test]
+fn test_decoding_with_added_bpe() {
+    let mut tokenizer = Tokenizer::from_file("data/llama-3-tokenizer.json").unwrap();
+    tokenizer.with_normalizer(Some(NormalizerWrapper::from(normalizers::ByteLevel::new())));
+    tokenizer.with_pre_tokenizer(Some(PreTokenizerWrapper::Split(
+        Split::new(
+            SplitPattern::Regex(r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+".into()),
+            SplitDelimiterBehavior::Isolated,
+            false,
+        )
+        .unwrap(),
+    )));
+    tokenizer.add_tokens(&[AddedToken::from("嗎", false).normalized(false)]);
+    let encoded = tokenizer
+        .encode("Hey! how is this token: 嗎", false)
+        .unwrap();
+    assert_eq!(
+        encoded.get_ids(),
+        [19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128256]
+    );
+    assert_eq!(
+        encoded.get_tokens(),
+        ["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "嗎"]
+    );
+
+    let decoded = tokenizer.decode(encoded.get_ids(), false);
+    assert_eq!(decoded.unwrap(), "Hey! how is this token: 嗎");
+
+    tokenizer.add_tokens(&[AddedToken::from("д", false).normalized(true)]);
+    let encoded = tokenizer
+        .encode("Hey! how is this token: д", false)
+        .unwrap();
+    assert_eq!(
+        encoded.get_ids(),
+        [19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128257]
+    );
+    assert_eq!(
+        encoded.get_tokens(),
+        ["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "д"]
+    );
+    let decoded = tokenizer.decode(encoded.get_ids(), false);
+    assert_eq!(decoded.unwrap(), "Hey! how is this token: д")
+}
+
+#[test]
+fn test_decode_stream_step_no_panic() {
+    let tokenizer = Tokenizer::from_file("data/llama-3-tokenizer.json").unwrap();
+
+    // "A B C D E F G H I J"
+    let mut decode_stream = tokenizer.decode_stream(false);
+    assert_eq!(decode_stream.step(32).unwrap(), Some("A".to_string()));
+    assert_eq!(decode_stream.step(426).unwrap(), Some(" B".to_string()));
+    assert_eq!(decode_stream.step(356).unwrap(), Some(" C".to_string()));
+    assert_eq!(decode_stream.step(423).unwrap(), Some(" D".to_string()));
+    assert_eq!(decode_stream.step(469).unwrap(), Some(" E".to_string()));
+    assert_eq!(decode_stream.step(435).unwrap(), Some(" F".to_string()));
+    assert_eq!(decode_stream.step(480).unwrap(), Some(" G".to_string()));
+    assert_eq!(decode_stream.step(473).unwrap(), Some(" H".to_string()));
+    assert_eq!(decode_stream.step(358).unwrap(), Some(" I".to_string()));
+    assert_eq!(decode_stream.step(622).unwrap(), Some(" J".to_string()));
+    // for (i, &token) in output_tokens.iter().enumerate() {}
+
+    // "삥뽕빵" (Korean words composed of 2-3 tokens: [80690, 98], [167, 121, 243], and [102457, 113])
+    let mut decode_stream = tokenizer.decode_stream(false);
+    assert_eq!(decode_stream.step(80690).unwrap(), None);
+    assert_eq!(decode_stream.step(98).unwrap(), Some("삥".to_string()));
+    assert_eq!(decode_stream.step(167).unwrap(), None);
+    assert_eq!(decode_stream.step(121).unwrap(), None);
+    assert_eq!(decode_stream.step(243).unwrap(), Some("뽕".to_string()));
+    assert_eq!(decode_stream.step(102457).unwrap(), None);
+    assert_eq!(decode_stream.step(113).unwrap(), Some("빵".to_string()));
+}
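
Usage sketch (illustrative, not part of the patch): how the simplified DecodeStream API is driven after dropping read_index, mirroring tokenizers/tests/stream.rs. It assumes data/llama-3-tokenizer.json has already been fetched by the new Makefile target, and reuses the "Hey" / "!" / " how" ids from the tests above.

use tokenizers::Tokenizer;

fn main() -> tokenizers::Result<()> {
    // Assumes the Llama 3 tokenizer JSON downloaded by the new `make` target above.
    let tokenizer = Tokenizer::from_file("data/llama-3-tokenizer.json")?;

    // The stream now only tracks `prefix`/`prefix_index`; `read_index` is gone.
    let mut stream = tokenizer.decode_stream(false);
    for id in [19182u32, 0, 1268] {
        // `step` returns Ok(None) until enough ids accumulate to emit complete text
        // (e.g. multi-token characters, as in the Korean case of the test).
        if let Some(chunk) = stream.step(id)? {
            print!("{chunk}");
        }
    }
    Ok(())
}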