Skip to content

Commit 4a0795f

Browse files
committed
[ENH] Add python client support to query on subset of IDs
1 parent 2802145 commit 4a0795f

File tree

12 files changed

+269
-8
lines changed

12 files changed

+269
-8
lines changed

chromadb/api/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def _query(
270270
self,
271271
collection_id: UUID,
272272
query_embeddings: Embeddings,
273+
ids: Optional[IDs] = None,
273274
n_results: int = 10,
274275
where: Optional[Where] = None,
275276
where_document: Optional[WhereDocument] = None,
@@ -724,6 +725,7 @@ def _query(
724725
self,
725726
collection_id: UUID,
726727
query_embeddings: Embeddings,
728+
ids: Optional[IDs] = None,
727729
n_results: int = 10,
728730
where: Optional[Where] = None,
729731
where_document: Optional[WhereDocument] = None,

chromadb/api/async_api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ async def _query(
264264
self,
265265
collection_id: UUID,
266266
query_embeddings: Embeddings,
267+
ids: Optional[IDs] = None,
267268
n_results: int = 10,
268269
where: Optional[Where] = None,
269270
where_document: Optional[WhereDocument] = None,
@@ -718,6 +719,7 @@ async def _query(
718719
self,
719720
collection_id: UUID,
720721
query_embeddings: Embeddings,
722+
ids: Optional[IDs] = None,
721723
n_results: int = 10,
722724
where: Optional[Where] = None,
723725
where_document: Optional[WhereDocument] = None,

chromadb/api/async_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ async def _query(
403403
self,
404404
collection_id: UUID,
405405
query_embeddings: Embeddings,
406+
ids: Optional[IDs] = None,
406407
n_results: int = 10,
407408
where: Optional[Where] = None,
408409
where_document: Optional[WhereDocument] = None,
@@ -411,6 +412,7 @@ async def _query(
411412
return await self._server._query(
412413
collection_id=collection_id,
413414
query_embeddings=query_embeddings,
415+
ids=ids,
414416
n_results=n_results,
415417
where=where,
416418
where_document=where_document,

chromadb/api/async_fastapi.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,7 @@ async def _query(
623623
self,
624624
collection_id: UUID,
625625
query_embeddings: Embeddings,
626+
ids: Optional[IDs] = None,
626627
n_results: int = 10,
627628
where: Optional[Where] = None,
628629
where_document: Optional[WhereDocument] = None,

chromadb/api/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,13 +374,15 @@ def _query(
374374
self,
375375
collection_id: UUID,
376376
query_embeddings: Embeddings,
377+
ids: Optional[IDs] = None,
377378
n_results: int = 10,
378379
where: Optional[Where] = None,
379380
where_document: Optional[WhereDocument] = None,
380381
include: Include = IncludeMetadataDocumentsDistances,
381382
) -> QueryResult:
382383
return self._server._query(
383384
collection_id=collection_id,
385+
ids=ids,
384386
tenant=self.tenant,
385387
database=self.database,
386388
query_embeddings=query_embeddings,

chromadb/api/fastapi.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,7 @@ def _query(
591591
self,
592592
collection_id: UUID,
593593
query_embeddings: Embeddings,
594+
ids: Optional[IDs] = None,
594595
n_results: int = 10,
595596
where: Optional[Where] = None,
596597
where_document: Optional[WhereDocument] = None,
@@ -606,6 +607,7 @@ def _query(
606607
"post",
607608
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/query",
608609
json={
610+
"ids": ids,
609611
"query_embeddings": convert_np_embeddings_to_list(query_embeddings)
610612
if query_embeddings is not None
611613
else None,

chromadb/api/models/AsyncCollection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ async def query(
169169
query_texts: Optional[OneOrMany[Document]] = None,
170170
query_images: Optional[OneOrMany[Image]] = None,
171171
query_uris: Optional[OneOrMany[URI]] = None,
172+
ids: Optional[IDs] = None,
172173
n_results: int = 10,
173174
where: Optional[Where] = None,
174175
where_document: Optional[WhereDocument] = None,
@@ -184,6 +185,7 @@ async def query(
184185
query_embeddings: The embeddings to get the closes neighbors of. Optional.
185186
query_texts: The document texts to get the closes neighbors of. Optional.
186187
query_images: The images to get the closes neighbors of. Optional.
188+
ids: A subset of ids to search within. Optional.
187189
n_results: The number of neighbors to return for each query_embedding or query_texts. Optional.
188190
where: A Where type dict used to filter results by. E.g. `{"$and": ["color" : "red", "price": {"$gte": 4.20}]}`. Optional.
189191
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{$contains: {"text": "hello"}}`. Optional.
@@ -213,6 +215,7 @@ async def query(
213215

214216
query_results = await self._client._query(
215217
collection_id=self.id,
218+
ids=ids,
216219
query_embeddings=query_request["embeddings"],
217220
n_results=query_request["n_results"],
218221
where=query_request["where"],

chromadb/api/models/Collection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def query(
172172
query_texts: Optional[OneOrMany[Document]] = None,
173173
query_images: Optional[OneOrMany[Image]] = None,
174174
query_uris: Optional[OneOrMany[URI]] = None,
175+
ids: Optional[IDs] = None,
175176
n_results: int = 10,
176177
where: Optional[Where] = None,
177178
where_document: Optional[WhereDocument] = None,
@@ -188,6 +189,7 @@ def query(
188189
query_texts: The document texts to get the closes neighbors of. Optional.
189190
query_images: The images to get the closes neighbors of. Optional.
190191
query_uris: The URIs to be used with data loader. Optional.
192+
ids: A subset of ids to search within. Optional.
191193
n_results: The number of neighbors to return for each query_embedding or query_texts. Optional.
192194
where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
193195
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{$contains: {"text": "hello"}}`. Optional.
@@ -217,6 +219,7 @@ def query(
217219

218220
query_results = self._client._query(
219221
collection_id=self.id,
222+
ids=ids,
220223
query_embeddings=query_request["embeddings"],
221224
n_results=query_request["n_results"],
222225
where=query_request["where"],

chromadb/api/rust.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ def _query(
472472
self,
473473
collection_id: UUID,
474474
query_embeddings: Embeddings,
475+
ids: Optional[IDs] = None,
475476
n_results: int = 10,
476477
where: Optional[Where] = None,
477478
where_document: Optional[WhereDocument] = None,
@@ -496,6 +497,7 @@ def _query(
496497

497498
rust_response = self.bindings.query(
498499
str(collection_id),
500+
ids,
499501
query_embeddings,
500502
n_results,
501503
json.dumps(where) if where else None,

chromadb/api/segment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,7 @@ def _query(
775775
self,
776776
collection_id: UUID,
777777
query_embeddings: Embeddings,
778+
ids: Optional[IDs] = None,
778779
n_results: int = 10,
779780
where: Optional[Where] = None,
780781
where_document: Optional[WhereDocument] = None,

0 commit comments

Comments
 (0)