Skip to content

Commit 0c1fc3b

Browse files
authored
[ENH] Implement literal provider (#4430)
## Description of changes _Summarize the changes made by this PR._ - Improvements & Bug fixes - N/A - New functionality - Implement the literal provider interface and its default implementations, which allow a literal expression to be evaluated to a set of documents. ## Test plan _How are these changes tested?_ - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Documentation Changes _Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_ <!-- Summary by @propel-code-bot --> --- This PR introduces a new NgramLiteralProvider trait and implementation that enables efficient evaluation of literal expressions against document collections. The implementation uses n-gram based matching to find documents containing specific literal patterns, with support for handling compound expressions (concatenation and alternation). The code includes optimizations for different literal types and large branching factors. **Key Changes:** • Added NgramLiteralProvider trait with methods for literal expression evaluation • Implemented match_literal_with_mask for efficient n-gram based document matching • Added support for evaluating compound literal expressions (concat, alternation) • Enhanced the Literal enum with a width() method for optimization **Affected Areas:** • rust/types/src/regex/literal_expr.rs • Dependencies (added async-trait) *This summary was automatically generated by @propel-code-bot*
1 parent 1be97d1 commit 0c1fc3b

File tree

3 files changed

+230
-2
lines changed

3 files changed

+230
-2
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/types/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ edition = "2021"
77
path = "src/lib.rs"
88

99
[dependencies]
10+
async-trait = { workspace = true }
1011
prost = { workspace = true }
1112
prost-types = { workspace = true }
1213
roaring = { workspace = true }

rust/types/src/regex/literal_expr.rs

Lines changed: 228 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
use std::{collections::HashMap, ops::RangeBounds};
2+
13
use regex_syntax::hir::ClassUnicode;
4+
use roaring::RoaringBitmap;
25

36
use super::hir::ChromaHir;
47

@@ -8,6 +11,15 @@ pub enum Literal {
811
Class(ClassUnicode),
912
}
1013

14+
impl Literal {
15+
pub fn width(&self) -> usize {
16+
match self {
17+
Literal::Char(_) => 1,
18+
Literal::Class(class_unicode) => class_unicode.iter().map(|range| range.len()).sum(),
19+
}
20+
}
21+
}
22+
1123
#[derive(Clone, Debug)]
1224
pub enum LiteralExpr {
1325
Literal(Vec<Literal>),
@@ -30,7 +42,7 @@ impl From<ChromaHir> for LiteralExpr {
3042
ChromaHir::Concat(repeat).into()
3143
}
3244
ChromaHir::Concat(hirs) => {
33-
let exprs = hirs.into_iter().fold(Vec::new(), |mut exprs, expr| {
45+
let mut exprs = hirs.into_iter().fold(Vec::new(), |mut exprs, expr| {
3446
match (exprs.last_mut(), expr.into()) {
3547
(Some(Self::Literal(literal)), Self::Literal(extra_literal)) => {
3648
literal.extend(extra_literal)
@@ -39,11 +51,225 @@ impl From<ChromaHir> for LiteralExpr {
3951
}
4052
exprs
4153
});
42-
Self::Concat(exprs)
54+
if exprs.len() > 1 {
55+
Self::Concat(exprs)
56+
} else if let Some(expr) = exprs.pop() {
57+
expr
58+
} else {
59+
Self::Literal(Vec::new())
60+
}
4361
}
4462
ChromaHir::Alternation(hirs) => {
4563
Self::Alternation(hirs.into_iter().map(Into::into).collect())
4664
}
4765
}
4866
}
4967
}
68+
69+
#[async_trait::async_trait]
70+
pub trait NgramLiteralProvider<E, const N: usize = 3> {
71+
// Return the max branching factor during the search
72+
fn maximum_branching_factor(&self) -> usize;
73+
74+
// Return the (ngram, doc_id, positions) for a range of ngrams
75+
async fn lookup_ngram_range<'me, NgramRange>(
76+
&'me self,
77+
ngram_range: NgramRange,
78+
) -> Result<Vec<(&'me str, u32, RoaringBitmap)>, E>
79+
where
80+
NgramRange: Clone + RangeBounds<&'me str> + Send + Sync;
81+
82+
// Return the documents containing the literals. The search space is restricted to the documents in the mask if specified
83+
// If all documents could contain the literals, Ok(None) is returned
84+
async fn match_literal_with_mask(
85+
&self,
86+
literals: &[Literal],
87+
mask: Option<&RoaringBitmap>,
88+
) -> Result<Option<RoaringBitmap>, E> {
89+
if mask.is_some_and(|m| m.is_empty()) {
90+
return Ok(mask.cloned());
91+
}
92+
93+
let (initial_literals, remaining_literals) = literals.split_at(N);
94+
let initial_ngrams =
95+
initial_literals
96+
.iter()
97+
.fold(vec![Vec::with_capacity(N)], |mut acc, lit| match lit {
98+
Literal::Char(c) => {
99+
acc.iter_mut().for_each(|s| s.push(*c));
100+
acc
101+
}
102+
Literal::Class(class_unicode) => {
103+
acc.into_iter()
104+
.flat_map(|s| {
105+
class_unicode.iter().flat_map(|r| r.start()..=r.end()).map(
106+
move |c| {
107+
let mut sc = s.clone();
108+
sc.push(c);
109+
sc
110+
},
111+
)
112+
})
113+
.collect()
114+
}
115+
});
116+
117+
// ngram suffix -> doc_id -> position
118+
let mut suffix_doc_pos: HashMap<Vec<char>, HashMap<u32, RoaringBitmap>> = HashMap::new();
119+
for ngram in initial_ngrams {
120+
let ngram_string = ngram.iter().collect::<String>();
121+
let ngram_doc_pos = self
122+
.lookup_ngram_range(ngram_string.as_str()..=ngram_string.as_str())
123+
.await?;
124+
125+
if ngram_doc_pos.is_empty() {
126+
continue;
127+
}
128+
129+
let suffix = ngram[1..].to_vec();
130+
for (_, doc_id, pos) in ngram_doc_pos {
131+
if mask.map(|m| m.contains(doc_id)).unwrap_or(mask.is_none()) {
132+
*suffix_doc_pos
133+
.entry(suffix.clone())
134+
.or_default()
135+
.entry(doc_id)
136+
.or_default() |= pos;
137+
}
138+
}
139+
}
140+
141+
for literal in remaining_literals {
142+
if suffix_doc_pos.is_empty() {
143+
break;
144+
}
145+
let mut new_suffix_doc_pos: HashMap<Vec<char>, HashMap<u32, RoaringBitmap>> =
146+
HashMap::new();
147+
for (mut suffix, doc_pos) in suffix_doc_pos {
148+
let ngram_ranges = match literal {
149+
Literal::Char(literal_char) => {
150+
suffix.push(*literal_char);
151+
vec![(suffix.clone(), suffix)]
152+
}
153+
Literal::Class(class_unicode) => class_unicode
154+
.iter()
155+
.map(|r| {
156+
let mut min_ngram = suffix.clone();
157+
min_ngram.push(r.start());
158+
let mut max_ngram = suffix.clone();
159+
max_ngram.push(r.end());
160+
(min_ngram, max_ngram)
161+
})
162+
.collect(),
163+
};
164+
165+
for (min_ngram, max_ngram) in ngram_ranges {
166+
let min_ngram_string = min_ngram.iter().collect::<String>();
167+
let max_ngram_string = max_ngram.iter().collect::<String>();
168+
let ngram_doc_pos = self
169+
.lookup_ngram_range(min_ngram_string.as_str()..=max_ngram_string.as_str())
170+
.await?;
171+
for (ngram, doc_id, next_pos) in ngram_doc_pos {
172+
if let Some(pos) = doc_pos.get(&doc_id) {
173+
if ngram.chars().last().is_some_and(|c| match literal {
174+
Literal::Char(literal_char) => c == *literal_char,
175+
Literal::Class(class_unicode) => {
176+
class_unicode.iter().any(|r| r.start() <= c && c <= r.end())
177+
}
178+
}) {
179+
// SAFETY(Sicheng): The RoaringBitmap iterator should be sorted
180+
let valid_next_pos = RoaringBitmap::from_sorted_iter(
181+
pos.iter()
182+
.filter_map(|p| next_pos.contains(p + 1).then_some(p + 1)),
183+
)
184+
.expect("RoaringBitmap iterator should be sorted");
185+
186+
if !valid_next_pos.is_empty() {
187+
let new_suffix = ngram.chars().skip(1).collect();
188+
*new_suffix_doc_pos
189+
.entry(new_suffix)
190+
.or_default()
191+
.entry(doc_id)
192+
.or_default() |= valid_next_pos;
193+
}
194+
}
195+
}
196+
}
197+
}
198+
}
199+
suffix_doc_pos = new_suffix_doc_pos;
200+
}
201+
202+
let result = suffix_doc_pos
203+
.into_values()
204+
.flat_map(|doc_pos| doc_pos.into_keys())
205+
.collect();
206+
Ok(Some(result))
207+
}
208+
209+
// Return the documents matching the literal expression. The search space is restricted to the documents in the mask if specified
210+
// If all documents could match the literal expression, Ok(None) is returned
211+
async fn match_literal_expression_with_mask(
212+
&self,
213+
literal_expression: &LiteralExpr,
214+
mask: Option<&RoaringBitmap>,
215+
) -> Result<Option<RoaringBitmap>, E> {
216+
match literal_expression {
217+
LiteralExpr::Literal(literals) => {
218+
let mut result = mask.cloned();
219+
for query in literals.split(|lit| lit.width() > self.maximum_branching_factor()) {
220+
if result.as_ref().is_some_and(|m| m.is_empty()) {
221+
break;
222+
}
223+
if query.len() >= N {
224+
result = self.match_literal_with_mask(query, result.as_ref()).await?;
225+
}
226+
}
227+
Ok(result)
228+
}
229+
LiteralExpr::Concat(literal_exprs) => {
230+
let mut result = mask.cloned();
231+
for expr in literal_exprs {
232+
if result.as_ref().is_some_and(|m| m.is_empty()) {
233+
break;
234+
}
235+
result = self
236+
.match_literal_expression_with_mask(expr, result.as_ref())
237+
.await?;
238+
}
239+
Ok(result)
240+
}
241+
LiteralExpr::Alternation(literal_exprs) => {
242+
let mut result = RoaringBitmap::new();
243+
for expr in literal_exprs {
244+
if let Some(matching_docs) =
245+
self.match_literal_expression_with_mask(expr, mask).await?
246+
{
247+
result |= matching_docs;
248+
} else {
249+
return Ok(mask.cloned());
250+
}
251+
}
252+
Ok(Some(result))
253+
}
254+
}
255+
}
256+
257+
// Return the documents matching the literal expression
258+
// If all documents could match the literal expression, Ok(None) is returned
259+
async fn match_literal_expression(
260+
&self,
261+
literal_expression: &LiteralExpr,
262+
) -> Result<Option<RoaringBitmap>, E> {
263+
self.match_literal_expression_with_mask(literal_expression, None)
264+
.await
265+
}
266+
267+
fn can_match_exactly(&self, literal_expression: &LiteralExpr) -> bool {
268+
match literal_expression {
269+
LiteralExpr::Literal(literals) => literals
270+
.iter()
271+
.all(|c| c.width() <= self.maximum_branching_factor()),
272+
LiteralExpr::Concat(_) | LiteralExpr::Alternation(_) => false,
273+
}
274+
}
275+
}

0 commit comments

Comments
 (0)