Commit 2868d4d

[BUG] ignore trigrams with null terminator byte when constructing full text index
1 parent a43b1f3 commit 2868d4d

2 files changed: +117 -28 lines changed
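In short: `TokenInstance::encode` becomes fallible. A trigram containing a `\0` character now yields `TokenInstanceEncodeError::NullTerminator`, every call site in the full text index writer skips such trigrams instead of indexing them, and a new test covers documents with embedded null bytes.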

rust/index/src/fulltext/types.rs

Lines changed: 84 additions & 16 deletions
```diff
@@ -1,4 +1,5 @@
 use super::util::TokenInstance;
+use super::util::TokenInstanceEncodeError;
 use chroma_blockstore::{BlockfileFlusher, BlockfileReader, BlockfileWriter};
 use chroma_error::{ChromaError, ErrorCodes};
 use futures::StreamExt;
@@ -85,11 +86,16 @@ impl FullTextIndexWriter {
             .clone()
             .token_stream(new_document)
             .process(&mut |token| {
-                token_instances.push(TokenInstance::encode(
+                match TokenInstance::encode(
                     token.text.as_str(),
                     offset_id,
                     Some(token.offset_from as u32),
-                ));
+                ) {
+                    Ok(encoded) => token_instances.push(encoded),
+                    Err(TokenInstanceEncodeError::NullTerminator) => {
+                        // ignore
+                    }
+                }
             });
     }

```
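This Ok/Err match is the pattern repeated at every call site below: a trigram containing `\0` is silently dropped rather than failing the whole mutation batch.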
```diff
@@ -104,29 +110,46 @@ impl FullTextIndexWriter {
             .clone()
             .token_stream(old_document)
             .process(&mut |token| {
-                trigrams_to_delete.insert(TokenInstance::encode(
+                match TokenInstance::encode(
                     token.text.as_str(),
                     offset_id,
-                    None,
-                ));
+                    Some(token.offset_from as u32),
+                ) {
+                    Ok(encoded) => {
+                        trigrams_to_delete.insert(encoded);
+                    }
+                    Err(TokenInstanceEncodeError::NullTerminator) => {
+                        // ignore
+                    }
+                }
             });

         // Add doc
         self.tokenizer
             .clone()
             .token_stream(new_document)
             .process(&mut |token| {
-                trigrams_to_delete.remove(&TokenInstance::encode(
-                    token.text.as_str(),
-                    offset_id,
-                    None,
-                ));
-
-                token_instances.push(TokenInstance::encode(
+                match TokenInstance::encode(token.text.as_str(), offset_id, None) {
+                    Ok(encoded) => {
+                        trigrams_to_delete.remove(&encoded);
+                    }
+                    Err(TokenInstanceEncodeError::NullTerminator) => {
+                        // ignore
+                    }
+                }
+
+                match TokenInstance::encode(
                     token.text.as_str(),
                     offset_id,
                     Some(token.offset_from as u32),
-                ));
+                ) {
+                    Ok(encoded) => {
+                        token_instances.push(encoded);
+                    }
+                    Err(TokenInstanceEncodeError::NullTerminator) => {
+                        // ignore
+                    }
+                }
             });

         token_instances.extend(trigrams_to_delete.into_iter());
@@ -143,11 +166,18 @@ impl FullTextIndexWriter {
             .clone()
             .token_stream(old_document)
             .process(&mut |token| {
-                trigrams_to_delete.insert(TokenInstance::encode(
+                match TokenInstance::encode(
                     token.text.as_str(),
                     offset_id,
-                    None,
-                ));
+                    Some(token.offset_from as u32),
+                ) {
+                    Ok(encoded) => {
+                        trigrams_to_delete.insert(encoded);
+                    }
+                    Err(TokenInstanceEncodeError::NullTerminator) => {
+                        // ignore
+                    }
+                }
             });

         token_instances.extend(trigrams_to_delete.into_iter());
```
```diff
@@ -909,6 +939,44 @@ mod tests {
         assert_eq!(res.len(), 3);
     }

+    #[tokio::test]
+    async fn test_document_with_null_terminators() {
+        let tmp_dir = tempdir().unwrap();
+        let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
+        let block_cache = new_cache_for_test();
+        let root_cache = new_cache_for_test();
+        let provider = BlockfileProvider::new_arrow(storage, 1024 * 1024, block_cache, root_cache);
+        let pl_blockfile_writer = provider
+            .write::<u32, Vec<u32>>(BlockfileWriterOptions::default().ordered_mutations())
+            .await
+            .unwrap();
+        let pl_blockfile_id = pl_blockfile_writer.id();
+
+        let tokenizer = NgramTokenizer::new(3, 3, false).unwrap();
+        let mut index_writer = FullTextIndexWriter::new(pl_blockfile_writer, tokenizer.clone());
+
+        index_writer
+            .handle_batch([DocumentMutation::Create {
+                offset_id: 1,
+                new_document: "hello \0 wor\0ld",
+            }])
+            .unwrap();
+
+        index_writer.write_to_blockfiles().await.unwrap();
+        let flusher = index_writer.commit().await.unwrap();
+        flusher.flush().await.unwrap();
+
+        let pl_blockfile_reader = provider
+            .read::<u32, &[u32]>(&pl_blockfile_id)
+            .await
+            .unwrap();
+        let tokenizer = NgramTokenizer::new(3, 3, false).unwrap();
+        let index_reader = FullTextIndexReader::new(pl_blockfile_reader, tokenizer);
+
+        let res = index_reader.search("hello").await.unwrap();
+        assert_eq!(res, RoaringBitmap::from([1]));
+    }
+
     #[tokio::test]
     async fn test_update_document() {
         let tmp_dir = tempdir().unwrap();
```
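The new test indexes a document containing embedded null bytes (`hello \0 wor\0ld`) and asserts that searching for a clean token still returns the document, confirming that null-containing trigrams are skipped without corrupting the rest of the index.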

rust/index/src/fulltext/util.rs

Lines changed: 33 additions & 12 deletions
```diff
@@ -1,3 +1,5 @@
+use thiserror::Error;
+
 /// A token instance is a unique value containing a trigram, an offset ID, and optionally a position within a document.
 /// These three attributes are packed into a single u128 value:
 /// - The trigram is a 63-bit value, packed into the top 64 bits.
@@ -8,15 +10,21 @@
 #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
 pub struct TokenInstance(u128);

-// Unicode characters only use 21 bits, so we can encode a trigram in 21 * 3 = 63 bits (a u64).
+/// Unicode characters only use 21 bits, so we can encode a trigram in 21 * 3 = 63 bits (a u64).
+/// Returns None if the string contains a null terminator.
 #[inline(always)]
-fn pack_trigram(s: &str) -> u64 {
+fn pack_trigram(s: &str) -> Option<u64> {
     let mut u = 0u64;
     for (i, c) in s.chars().take(3).enumerate() {
+        if c == '\0' {
+            return None;
+        }
+
         let shift = (2 - i) * 21;
         u |= (c as u64) << shift;
     }
-    u
+
+    Some(u)
 }

 fn encode_utf8_unchecked(c: u32, buf: &mut [u8]) -> usize {
```
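To make the doc comment above concrete, here is a standalone sketch (not the crate's code, just the same arithmetic) of the packing scheme: each Unicode scalar value fits in 21 bits, so three of them occupy the low 63 bits of a `u64`, and after this change a trigram containing `\0` is rejected instead of packed.

```rust
// Standalone sketch of the 21-bit trigram packing described above.
fn pack_trigram_sketch(s: &str) -> Option<u64> {
    let mut u = 0u64;
    for (i, c) in s.chars().take(3).enumerate() {
        if c == '\0' {
            // a '\0' packs to all-zero bits, indistinguishable from "no character"
            return None;
        }
        let shift = (2 - i) * 21; // first char lands in the highest 21-bit slot
        u |= (c as u64) << shift;
    }
    Some(u)
}

fn main() {
    // "abc" packs as ('a' << 42) | ('b' << 21) | 'c'
    assert_eq!(pack_trigram_sketch("abc"), Some((97u64 << 42) | (98 << 21) | 99));
    // any trigram containing a null byte is now rejected
    assert_eq!(pack_trigram_sketch("wo\0"), None);
}
```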
```diff
@@ -74,16 +82,29 @@ fn unpack_trigram(u: u64) -> String {
     s
 }

+#[derive(Debug, Error)]
+pub enum TokenInstanceEncodeError {
+    #[error("Token contains null terminator")]
+    NullTerminator,
+}
+
 impl TokenInstance {
     pub const MAX: Self = Self(u128::MAX);

     #[inline(always)]
-    pub fn encode(token: &str, offset_id: u32, position: Option<u32>) -> Self {
-        TokenInstance(
-            ((pack_trigram(token) as u128) << 64)
-                | ((offset_id as u128) << 32)
-                | (position.map(|o| o | (1 << 31)).unwrap_or(0) as u128),
-        )
+    pub fn encode(
+        token: &str,
+        offset_id: u32,
+        position: Option<u32>,
+    ) -> Result<Self, TokenInstanceEncodeError> {
+        match pack_trigram(token) {
+            Some(packed) => Ok(TokenInstance(
+                ((packed as u128) << 64)
+                    | ((offset_id as u128) << 32)
+                    | (position.map(|o| o | (1 << 31)).unwrap_or(0) as u128),
+            )),
+            None => Err(TokenInstanceEncodeError::NullTerminator),
+        }
     }

     #[inline(always)]
```
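For reference, here is a minimal standalone sketch of the `u128` layout that `encode` now wraps in a `Result`, following the field layout documented at the top of this file: trigram in the top 64 bits, offset ID in bits 32..63, and position in the low 32 bits with bit 31 serving as a "position present" flag.

```rust
// Standalone sketch of the u128 layout produced by TokenInstance::encode.
fn encode_sketch(packed_trigram: u64, offset_id: u32, position: Option<u32>) -> u128 {
    ((packed_trigram as u128) << 64)        // trigram: top 64 bits (63 used)
        | ((offset_id as u128) << 32)       // offset ID: bits 32..63
        | (position.map(|p| p | (1 << 31)).unwrap_or(0) as u128) // position + flag bit
}

fn main() {
    let v = encode_sketch(0x1234, 7, Some(5));
    assert_eq!((v >> 64) as u64, 0x1234);   // trigram recovered
    assert_eq!((v >> 32) as u32, 7);        // offset ID recovered
    assert_eq!(v as u32, 5 | (1 << 31));    // position carries the presence flag
    assert_eq!(encode_sketch(0x1234, 7, None) as u32, 0); // no position: low bits zero
}
```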
```diff
@@ -121,7 +142,7 @@ mod tests {
     proptest! {
         #[test]
         fn test_pack_unpack_trigram(token in "\\PC{3}", offset_id in 0..u32::MAX, position in proptest::option::of((0..u32::MAX).prop_map(|v| v >> 1))) {
-            let encoded = TokenInstance::encode(&token, offset_id, position);
+            let encoded = TokenInstance::encode(&token, offset_id, position).unwrap();
             let decoded_token = encoded.get_token();
             let decoded_offset_id = encoded.get_offset_id();
             let decoded_position = encoded.get_position();
@@ -133,8 +154,8 @@ mod tests {

         #[test]
         fn test_omit_position(token in "\\PC{3}", offset_id in 0..u32::MAX, position1 in proptest::option::of(0..u32::MAX), position2 in proptest::option::of(0..u32::MAX)) {
-            let encoded1 = TokenInstance::encode(&token, offset_id, position1);
-            let encoded2 = TokenInstance::encode(&token, offset_id, position2);
+            let encoded1 = TokenInstance::encode(&token, offset_id, position1).unwrap();
+            let encoded2 = TokenInstance::encode(&token, offset_id, position2).unwrap();

             assert_eq!(encoded1.omit_position(), encoded2.omit_position(), "Omitting position should make two token instances equal");
             assert_eq!(encoded1.omit_position().get_token(), encoded1.get_token(), "Omitting position should not change the token");
```
