Skip to content

Commit cac7f29

Browse files
committed
[ENH] Add python client support to query on subset of IDs
1 parent 7a3f562 commit cac7f29

File tree

13 files changed

+275
-11
lines changed

13 files changed

+275
-11
lines changed

chromadb/api/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def _query(
270270
self,
271271
collection_id: UUID,
272272
query_embeddings: Embeddings,
273+
ids: Optional[IDs] = None,
273274
n_results: int = 10,
274275
where: Optional[Where] = None,
275276
where_document: Optional[WhereDocument] = None,
@@ -280,6 +281,7 @@ def _query(
280281
Args:
281282
collection_id: The UUID of the collection to query.
282283
query_embeddings: The embeddings to use as the query.
284+
ids: The IDs to filter by during the query. Defaults to None.
283285
n_results: The number of results to return. Defaults to 10.
284286
where: Conditional filtering on metadata. Defaults to None.
285287
where_document: Conditional filtering on documents. Defaults to None.
@@ -724,6 +726,7 @@ def _query(
724726
self,
725727
collection_id: UUID,
726728
query_embeddings: Embeddings,
729+
ids: Optional[IDs] = None,
727730
n_results: int = 10,
728731
where: Optional[Where] = None,
729732
where_document: Optional[WhereDocument] = None,

chromadb/api/async_api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ async def _query(
264264
self,
265265
collection_id: UUID,
266266
query_embeddings: Embeddings,
267+
ids: Optional[IDs] = None,
267268
n_results: int = 10,
268269
where: Optional[Where] = None,
269270
where_document: Optional[WhereDocument] = None,
@@ -718,6 +719,7 @@ async def _query(
718719
self,
719720
collection_id: UUID,
720721
query_embeddings: Embeddings,
722+
ids: Optional[IDs] = None,
721723
n_results: int = 10,
722724
where: Optional[Where] = None,
723725
where_document: Optional[WhereDocument] = None,

chromadb/api/async_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ async def _query(
403403
self,
404404
collection_id: UUID,
405405
query_embeddings: Embeddings,
406+
ids: Optional[IDs] = None,
406407
n_results: int = 10,
407408
where: Optional[Where] = None,
408409
where_document: Optional[WhereDocument] = None,
@@ -411,6 +412,7 @@ async def _query(
411412
return await self._server._query(
412413
collection_id=collection_id,
413414
query_embeddings=query_embeddings,
415+
ids=ids,
414416
n_results=n_results,
415417
where=where,
416418
where_document=where_document,

chromadb/api/async_fastapi.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,7 @@ async def _query(
623623
self,
624624
collection_id: UUID,
625625
query_embeddings: Embeddings,
626+
ids: Optional[IDs] = None,
626627
n_results: int = 10,
627628
where: Optional[Where] = None,
628629
where_document: Optional[WhereDocument] = None,

chromadb/api/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,13 +374,15 @@ def _query(
374374
self,
375375
collection_id: UUID,
376376
query_embeddings: Embeddings,
377+
ids: Optional[IDs] = None,
377378
n_results: int = 10,
378379
where: Optional[Where] = None,
379380
where_document: Optional[WhereDocument] = None,
380381
include: Include = IncludeMetadataDocumentsDistances,
381382
) -> QueryResult:
382383
return self._server._query(
383384
collection_id=collection_id,
385+
ids=ids,
384386
tenant=self.tenant,
385387
database=self.database,
386388
query_embeddings=query_embeddings,

chromadb/api/fastapi.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,7 @@ def _query(
591591
self,
592592
collection_id: UUID,
593593
query_embeddings: Embeddings,
594+
ids: Optional[IDs] = None,
594595
n_results: int = 10,
595596
where: Optional[Where] = None,
596597
where_document: Optional[WhereDocument] = None,
@@ -606,6 +607,7 @@ def _query(
606607
"post",
607608
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/query",
608609
json={
610+
"ids": ids,
609611
"query_embeddings": convert_np_embeddings_to_list(query_embeddings)
610612
if query_embeddings is not None
611613
else None,

chromadb/api/models/AsyncCollection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ async def query(
169169
query_texts: Optional[OneOrMany[Document]] = None,
170170
query_images: Optional[OneOrMany[Image]] = None,
171171
query_uris: Optional[OneOrMany[URI]] = None,
172+
ids: Optional[IDs] = None,
172173
n_results: int = 10,
173174
where: Optional[Where] = None,
174175
where_document: Optional[WhereDocument] = None,
@@ -184,6 +185,7 @@ async def query(
184185
query_embeddings: The embeddings to get the closest neighbors of. Optional.
185186
query_texts: The document texts to get the closest neighbors of. Optional.
186187
query_images: The images to get the closest neighbors of. Optional.
188+
ids: A subset of ids to search within. Optional.
187189
n_results: The number of neighbors to return for each query_embedding or query_texts. Optional.
188190
where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
189191
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
@@ -213,6 +215,7 @@ async def query(
213215

214216
query_results = await self._client._query(
215217
collection_id=self.id,
218+
ids=ids,
216219
query_embeddings=query_request["embeddings"],
217220
n_results=query_request["n_results"],
218221
where=query_request["where"],

chromadb/api/models/Collection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def query(
172172
query_texts: Optional[OneOrMany[Document]] = None,
173173
query_images: Optional[OneOrMany[Image]] = None,
174174
query_uris: Optional[OneOrMany[URI]] = None,
175+
ids: Optional[IDs] = None,
175176
n_results: int = 10,
176177
where: Optional[Where] = None,
177178
where_document: Optional[WhereDocument] = None,
@@ -188,6 +189,7 @@ def query(
188189
query_texts: The document texts to get the closest neighbors of. Optional.
189190
query_images: The images to get the closest neighbors of. Optional.
190191
query_uris: The URIs to be used with data loader. Optional.
192+
ids: A subset of ids to search within. Optional.
191193
n_results: The number of neighbors to return for each query_embedding or query_texts. Optional.
192194
where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
193195
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
@@ -217,6 +219,7 @@ def query(
217219

218220
query_results = self._client._query(
219221
collection_id=self.id,
222+
ids=ids,
220223
query_embeddings=query_request["embeddings"],
221224
n_results=query_request["n_results"],
222225
where=query_request["where"],

chromadb/api/rust.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ def _query(
472472
self,
473473
collection_id: UUID,
474474
query_embeddings: Embeddings,
475+
ids: Optional[IDs] = None,
475476
n_results: int = 10,
476477
where: Optional[Where] = None,
477478
where_document: Optional[WhereDocument] = None,
@@ -480,10 +481,12 @@ def _query(
480481
database: str = DEFAULT_DATABASE,
481482
) -> QueryResult:
482483
query_amount = len(query_embeddings)
484+
filtered_ids_amount = len(ids) if ids else 0
483485
self.product_telemetry_client.capture(
484486
CollectionQueryEvent(
485487
collection_uuid=str(collection_id),
486488
query_amount=query_amount,
489+
filtered_ids_amount=filtered_ids_amount,
487490
n_results=n_results,
488491
with_metadata_filter=query_amount if where is not None else 0,
489492
with_document_filter=query_amount if where_document is not None else 0,
@@ -496,6 +499,7 @@ def _query(
496499

497500
rust_response = self.bindings.query(
498501
str(collection_id),
502+
ids,
499503
query_embeddings,
500504
n_results,
501505
json.dumps(where) if where else None,

chromadb/api/segment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,7 @@ def _query(
775775
self,
776776
collection_id: UUID,
777777
query_embeddings: Embeddings,
778+
ids: Optional[IDs] = None,
778779
n_results: int = 10,
779780
where: Optional[Where] = None,
780781
where_document: Optional[WhereDocument] = None,

chromadb/telemetry/product/events.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class CollectionQueryEvent(ProductTelemetryEvent):
137137
batch_size: int
138138
collection_uuid: str
139139
query_amount: int
140+
filtered_ids_amount: int
140141
with_metadata_filter: int
141142
with_document_filter: int
142143
n_results: int
@@ -149,6 +150,7 @@ def __init__(
149150
self,
150151
collection_uuid: str,
151152
query_amount: int,
153+
filtered_ids_amount: int,
152154
with_metadata_filter: int,
153155
with_document_filter: int,
154156
n_results: int,
@@ -161,6 +163,7 @@ def __init__(
161163
super().__init__()
162164
self.collection_uuid = collection_uuid
163165
self.query_amount = query_amount
166+
self.filtered_ids_amount = filtered_ids_amount
164167
self.with_metadata_filter = with_metadata_filter
165168
self.with_document_filter = with_document_filter
166169
self.n_results = n_results
@@ -182,6 +185,7 @@ def batch(self, other: "ProductTelemetryEvent") -> "CollectionQueryEvent":
182185
return CollectionQueryEvent(
183186
collection_uuid=self.collection_uuid,
184187
query_amount=total_amount,
188+
filtered_ids_amount=self.filtered_ids_amount + other.filtered_ids_amount,
185189
with_metadata_filter=self.with_metadata_filter + other.with_metadata_filter,
186190
with_document_filter=self.with_document_filter + other.with_document_filter,
187191
n_results=self.n_results + other.n_results,

0 commit comments

Comments
 (0)