Skip to content

Commit dc975f7

Browse files
[ENH] Add python & js client support to query on subset of IDs (chroma-core#4250)
2 parents 95c0748 + 827f0f9 commit dc975f7

File tree

22 files changed

+582
-16
lines changed

22 files changed

+582
-16
lines changed

chromadb/api/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def _query(
270270
self,
271271
collection_id: UUID,
272272
query_embeddings: Embeddings,
273+
ids: Optional[IDs] = None,
273274
n_results: int = 10,
274275
where: Optional[Where] = None,
275276
where_document: Optional[WhereDocument] = None,
@@ -280,6 +281,7 @@ def _query(
280281
Args:
281282
collection_id: The UUID of the collection to query.
282283
query_embeddings: The embeddings to use as the query.
284+
ids: The IDs to filter by during the query. Defaults to None.
283285
n_results: The number of results to return. Defaults to 10.
284286
where: Conditional filtering on metadata. Defaults to None.
285287
where_document: Conditional filtering on documents. Defaults to None.
@@ -734,6 +736,7 @@ def _query(
734736
self,
735737
collection_id: UUID,
736738
query_embeddings: Embeddings,
739+
ids: Optional[IDs] = None,
737740
n_results: int = 10,
738741
where: Optional[Where] = None,
739742
where_document: Optional[WhereDocument] = None,

chromadb/api/async_api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ async def _query(
264264
self,
265265
collection_id: UUID,
266266
query_embeddings: Embeddings,
267+
ids: Optional[IDs] = None,
267268
n_results: int = 10,
268269
where: Optional[Where] = None,
269270
where_document: Optional[WhereDocument] = None,
@@ -728,6 +729,7 @@ async def _query(
728729
self,
729730
collection_id: UUID,
730731
query_embeddings: Embeddings,
732+
ids: Optional[IDs] = None,
731733
n_results: int = 10,
732734
where: Optional[Where] = None,
733735
where_document: Optional[WhereDocument] = None,

chromadb/api/async_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ async def _query(
403403
self,
404404
collection_id: UUID,
405405
query_embeddings: Embeddings,
406+
ids: Optional[IDs] = None,
406407
n_results: int = 10,
407408
where: Optional[Where] = None,
408409
where_document: Optional[WhereDocument] = None,
@@ -411,6 +412,7 @@ async def _query(
411412
return await self._server._query(
412413
collection_id=collection_id,
413414
query_embeddings=query_embeddings,
415+
ids=ids,
414416
n_results=n_results,
415417
where=where,
416418
where_document=where_document,

chromadb/api/async_fastapi.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,7 @@ async def _query(
617617
self,
618618
collection_id: UUID,
619619
query_embeddings: Embeddings,
620+
ids: Optional[IDs] = None,
620621
n_results: int = 10,
621622
where: Optional[Where] = None,
622623
where_document: Optional[WhereDocument] = None,
@@ -631,6 +632,7 @@ async def _query(
631632
"post",
632633
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/query",
633634
json={
635+
"ids": ids,
634636
"query_embeddings": convert_np_embeddings_to_list(query_embeddings)
635637
if query_embeddings is not None
636638
else None,

chromadb/api/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,13 +374,15 @@ def _query(
374374
self,
375375
collection_id: UUID,
376376
query_embeddings: Embeddings,
377+
ids: Optional[IDs] = None,
377378
n_results: int = 10,
378379
where: Optional[Where] = None,
379380
where_document: Optional[WhereDocument] = None,
380381
include: Include = IncludeMetadataDocumentsDistances,
381382
) -> QueryResult:
382383
return self._server._query(
383384
collection_id=collection_id,
385+
ids=ids,
384386
tenant=self.tenant,
385387
database=self.database,
386388
query_embeddings=query_embeddings,

chromadb/api/fastapi.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,7 @@ def _query(
588588
self,
589589
collection_id: UUID,
590590
query_embeddings: Embeddings,
591+
ids: Optional[IDs] = None,
591592
n_results: int = 10,
592593
where: Optional[Where] = None,
593594
where_document: Optional[WhereDocument] = None,
@@ -603,6 +604,7 @@ def _query(
603604
"post",
604605
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/query",
605606
json={
607+
"ids": ids,
606608
"query_embeddings": convert_np_embeddings_to_list(query_embeddings)
607609
if query_embeddings is not None
608610
else None,

chromadb/api/models/AsyncCollection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ async def query(
169169
query_texts: Optional[OneOrMany[Document]] = None,
170170
query_images: Optional[OneOrMany[Image]] = None,
171171
query_uris: Optional[OneOrMany[URI]] = None,
172+
ids: Optional[OneOrMany[ID]] = None,
172173
n_results: int = 10,
173174
where: Optional[Where] = None,
174175
where_document: Optional[WhereDocument] = None,
@@ -184,6 +185,7 @@ async def query(
184185
query_embeddings: The embeddings to get the closest neighbors of. Optional.
185186
query_texts: The document texts to get the closest neighbors of. Optional.
186187
query_images: The images to get the closest neighbors of. Optional.
188+
ids: A subset of ids to search within. Optional.
187189
n_results: The number of neighbors to return for each query_embedding or query_texts. Optional.
188190
where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
189191
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
@@ -205,6 +207,7 @@ async def query(
205207
query_texts=query_texts,
206208
query_images=query_images,
207209
query_uris=query_uris,
210+
ids=ids,
208211
n_results=n_results,
209212
where=where,
210213
where_document=where_document,
@@ -213,6 +216,7 @@ async def query(
213216

214217
query_results = await self._client._query(
215218
collection_id=self.id,
219+
ids=query_request["ids"],
216220
query_embeddings=query_request["embeddings"],
217221
n_results=query_request["n_results"],
218222
where=query_request["where"],
@@ -279,7 +283,7 @@ async def fork(
279283
client=self._client,
280284
model=model,
281285
embedding_function=self._embedding_function,
282-
data_loader=self._data_loader
286+
data_loader=self._data_loader,
283287
)
284288

285289
async def update(

chromadb/api/models/Collection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def query(
172172
query_texts: Optional[OneOrMany[Document]] = None,
173173
query_images: Optional[OneOrMany[Image]] = None,
174174
query_uris: Optional[OneOrMany[URI]] = None,
175+
ids: Optional[OneOrMany[ID]] = None,
175176
n_results: int = 10,
176177
where: Optional[Where] = None,
177178
where_document: Optional[WhereDocument] = None,
@@ -188,6 +189,7 @@ def query(
188189
query_texts: The document texts to get the closest neighbors of. Optional.
189190
query_images: The images to get the closest neighbors of. Optional.
190191
query_uris: The URIs to be used with data loader. Optional.
192+
ids: A subset of ids to search within. Optional.
191193
n_results: The number of neighbors to return for each query_embedding or query_texts. Optional.
192194
where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
193195
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
@@ -209,6 +211,7 @@ def query(
209211
query_texts=query_texts,
210212
query_images=query_images,
211213
query_uris=query_uris,
214+
ids=ids,
212215
n_results=n_results,
213216
where=where,
214217
where_document=where_document,
@@ -217,6 +220,7 @@ def query(
217220

218221
query_results = self._client._query(
219222
collection_id=self.id,
223+
ids=query_request["ids"],
220224
query_embeddings=query_request["embeddings"],
221225
n_results=query_request["n_results"],
222226
where=query_request["where"],
@@ -285,7 +289,7 @@ def fork(
285289
client=self._client,
286290
model=model,
287291
embedding_function=self._embedding_function,
288-
data_loader=self._data_loader
292+
data_loader=self._data_loader,
289293
)
290294

291295
def update(

chromadb/api/models/CollectionCommon.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ def _validate_and_prepare_query_request(
294294
query_texts: Optional[OneOrMany[Document]],
295295
query_images: Optional[OneOrMany[Image]],
296296
query_uris: Optional[OneOrMany[URI]],
297+
ids: Optional[OneOrMany[ID]],
297298
n_results: int,
298299
where: Optional[Where],
299300
where_document: Optional[WhereDocument],
@@ -307,6 +308,8 @@ def _validate_and_prepare_query_request(
307308
uris=query_uris,
308309
)
309310

311+
filter_ids = maybe_cast_one_to_many(ids)
312+
310313
filters = FilterSet(
311314
where=where,
312315
where_document=where_document,
@@ -335,6 +338,7 @@ def _validate_and_prepare_query_request(
335338

336339
return QueryRequest(
337340
embeddings=request_embeddings,
341+
ids=filter_ids,
338342
where=request_where,
339343
where_document=request_where_document,
340344
include=request_include,

chromadb/api/rust.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ def _query(
482482
self,
483483
collection_id: UUID,
484484
query_embeddings: Embeddings,
485+
ids: Optional[IDs] = None,
485486
n_results: int = 10,
486487
where: Optional[Where] = None,
487488
where_document: Optional[WhereDocument] = None,
@@ -490,10 +491,12 @@ def _query(
490491
database: str = DEFAULT_DATABASE,
491492
) -> QueryResult:
492493
query_amount = len(query_embeddings)
494+
filtered_ids_amount = len(ids) if ids else 0
493495
self.product_telemetry_client.capture(
494496
CollectionQueryEvent(
495497
collection_uuid=str(collection_id),
496498
query_amount=query_amount,
499+
filtered_ids_amount=filtered_ids_amount,
497500
n_results=n_results,
498501
with_metadata_filter=query_amount if where is not None else 0,
499502
with_document_filter=query_amount if where_document is not None else 0,
@@ -506,6 +509,7 @@ def _query(
506509

507510
rust_response = self.bindings.query(
508511
str(collection_id),
512+
ids,
509513
query_embeddings,
510514
n_results,
511515
json.dumps(where) if where else None,

chromadb/api/segment.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,7 @@ def _query(
785785
self,
786786
collection_id: UUID,
787787
query_embeddings: Embeddings,
788+
ids: Optional[IDs] = None,
788789
n_results: int = 10,
789790
where: Optional[Where] = None,
790791
where_document: Optional[WhereDocument] = None,
@@ -801,10 +802,12 @@ def _query(
801802
)
802803

803804
query_amount = len(query_embeddings)
805+
ids_amount = len(ids) if ids else 0
804806
self._product_telemetry_client.capture(
805807
CollectionQueryEvent(
806808
collection_uuid=str(collection_id),
807809
query_amount=query_amount,
810+
filtered_ids_amount=ids_amount,
808811
n_results=n_results,
809812
with_metadata_filter=query_amount if where is not None else 0,
810813
with_document_filter=query_amount if where_document is not None else 0,

chromadb/api/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,7 @@ class GetResult(TypedDict):
396396

397397
class QueryRequest(TypedDict):
398398
embeddings: Embeddings
399+
ids: Optional[IDs]
399400
where: Optional[Where]
400401
where_document: Optional[WhereDocument]
401402
include: Include

chromadb/telemetry/product/events.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class CollectionQueryEvent(ProductTelemetryEvent):
137137
batch_size: int
138138
collection_uuid: str
139139
query_amount: int
140+
filtered_ids_amount: int
140141
with_metadata_filter: int
141142
with_document_filter: int
142143
n_results: int
@@ -149,6 +150,7 @@ def __init__(
149150
self,
150151
collection_uuid: str,
151152
query_amount: int,
153+
filtered_ids_amount: int,
152154
with_metadata_filter: int,
153155
with_document_filter: int,
154156
n_results: int,
@@ -161,6 +163,7 @@ def __init__(
161163
super().__init__()
162164
self.collection_uuid = collection_uuid
163165
self.query_amount = query_amount
166+
self.filtered_ids_amount = filtered_ids_amount
164167
self.with_metadata_filter = with_metadata_filter
165168
self.with_document_filter = with_document_filter
166169
self.n_results = n_results
@@ -182,6 +185,7 @@ def batch(self, other: "ProductTelemetryEvent") -> "CollectionQueryEvent":
182185
return CollectionQueryEvent(
183186
collection_uuid=self.collection_uuid,
184187
query_amount=total_amount,
188+
filtered_ids_amount=self.filtered_ids_amount + other.filtered_ids_amount,
185189
with_metadata_filter=self.with_metadata_filter + other.with_metadata_filter,
186190
with_document_filter=self.with_document_filter + other.with_document_filter,
187191
n_results=self.n_results + other.n_results,

0 commit comments

Comments
 (0)