Skip to content

Commit 99a1335

Browse files
committed
[ENH] Add python client support to query on subset of IDs
1 parent 357c414 commit 99a1335

File tree

17 files changed

+190
-14
lines changed

17 files changed

+190
-14
lines changed

chromadb/api/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def _query(
270270
self,
271271
collection_id: UUID,
272272
query_embeddings: Embeddings,
273+
ids: Optional[IDs] = None,
273274
n_results: int = 10,
274275
where: Optional[Where] = None,
275276
where_document: Optional[WhereDocument] = None,
@@ -280,6 +281,7 @@ def _query(
280281
Args:
281282
collection_id: The UUID of the collection to query.
282283
query_embeddings: The embeddings to use as the query.
284+
ids: The IDs to filter by during the query. Defaults to None.
283285
n_results: The number of results to return. Defaults to 10.
284286
where: Conditional filtering on metadata. Defaults to None.
285287
where_document: Conditional filtering on documents. Defaults to None.
@@ -724,6 +726,7 @@ def _query(
724726
self,
725727
collection_id: UUID,
726728
query_embeddings: Embeddings,
729+
ids: Optional[IDs] = None,
727730
n_results: int = 10,
728731
where: Optional[Where] = None,
729732
where_document: Optional[WhereDocument] = None,

chromadb/api/async_api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ async def _query(
264264
self,
265265
collection_id: UUID,
266266
query_embeddings: Embeddings,
267+
ids: Optional[IDs] = None,
267268
n_results: int = 10,
268269
where: Optional[Where] = None,
269270
where_document: Optional[WhereDocument] = None,
@@ -718,6 +719,7 @@ async def _query(
718719
self,
719720
collection_id: UUID,
720721
query_embeddings: Embeddings,
722+
ids: Optional[IDs] = None,
721723
n_results: int = 10,
722724
where: Optional[Where] = None,
723725
where_document: Optional[WhereDocument] = None,

chromadb/api/async_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ async def _query(
403403
self,
404404
collection_id: UUID,
405405
query_embeddings: Embeddings,
406+
ids: Optional[IDs] = None,
406407
n_results: int = 10,
407408
where: Optional[Where] = None,
408409
where_document: Optional[WhereDocument] = None,
@@ -411,6 +412,7 @@ async def _query(
411412
return await self._server._query(
412413
collection_id=collection_id,
413414
query_embeddings=query_embeddings,
415+
ids=ids,
414416
n_results=n_results,
415417
where=where,
416418
where_document=where_document,

chromadb/api/async_fastapi.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,7 @@ async def _query(
600600
self,
601601
collection_id: UUID,
602602
query_embeddings: Embeddings,
603+
ids: Optional[IDs] = None,
603604
n_results: int = 10,
604605
where: Optional[Where] = None,
605606
where_document: Optional[WhereDocument] = None,

chromadb/api/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,13 +374,15 @@ def _query(
374374
self,
375375
collection_id: UUID,
376376
query_embeddings: Embeddings,
377+
ids: Optional[IDs] = None,
377378
n_results: int = 10,
378379
where: Optional[Where] = None,
379380
where_document: Optional[WhereDocument] = None,
380381
include: Include = IncludeMetadataDocumentsDistances,
381382
) -> QueryResult:
382383
return self._server._query(
383384
collection_id=collection_id,
385+
ids=ids,
384386
tenant=self.tenant,
385387
database=self.database,
386388
query_embeddings=query_embeddings,

chromadb/api/fastapi.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,7 @@ def _query(
570570
self,
571571
collection_id: UUID,
572572
query_embeddings: Embeddings,
573+
ids: Optional[IDs] = None,
573574
n_results: int = 10,
574575
where: Optional[Where] = None,
575576
where_document: Optional[WhereDocument] = None,
@@ -585,6 +586,7 @@ def _query(
585586
"post",
586587
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/query",
587588
json={
589+
"ids": ids,
588590
"query_embeddings": convert_np_embeddings_to_list(query_embeddings)
589591
if query_embeddings is not None
590592
else None,

chromadb/api/models/AsyncCollection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ async def query(
169169
query_texts: Optional[OneOrMany[Document]] = None,
170170
query_images: Optional[OneOrMany[Image]] = None,
171171
query_uris: Optional[OneOrMany[URI]] = None,
172+
ids: Optional[IDs] = None,
172173
n_results: int = 10,
173174
where: Optional[Where] = None,
174175
where_document: Optional[WhereDocument] = None,
@@ -184,6 +185,7 @@ async def query(
184185
query_embeddings: The embeddings to get the closes neighbors of. Optional.
185186
query_texts: The document texts to get the closes neighbors of. Optional.
186187
query_images: The images to get the closes neighbors of. Optional.
188+
ids: A subset of ids to search within. Optional.
187189
n_results: The number of neighbors to return for each query_embedding or query_texts. Optional.
188190
where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
189191
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
@@ -213,6 +215,7 @@ async def query(
213215

214216
query_results = await self._client._query(
215217
collection_id=self.id,
218+
ids=ids,
216219
query_embeddings=query_request["embeddings"],
217220
n_results=query_request["n_results"],
218221
where=query_request["where"],

chromadb/api/models/Collection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def query(
172172
query_texts: Optional[OneOrMany[Document]] = None,
173173
query_images: Optional[OneOrMany[Image]] = None,
174174
query_uris: Optional[OneOrMany[URI]] = None,
175+
ids: Optional[IDs] = None,
175176
n_results: int = 10,
176177
where: Optional[Where] = None,
177178
where_document: Optional[WhereDocument] = None,
@@ -188,6 +189,7 @@ def query(
188189
query_texts: The document texts to get the closes neighbors of. Optional.
189190
query_images: The images to get the closes neighbors of. Optional.
190191
query_uris: The URIs to be used with data loader. Optional.
192+
ids: A subset of ids to search within. Optional.
191193
n_results: The number of neighbors to return for each query_embedding or query_texts. Optional.
192194
where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
193195
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
@@ -217,6 +219,7 @@ def query(
217219

218220
query_results = self._client._query(
219221
collection_id=self.id,
222+
ids=ids,
220223
query_embeddings=query_request["embeddings"],
221224
n_results=query_request["n_results"],
222225
where=query_request["where"],

chromadb/api/rust.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ def _query(
472472
self,
473473
collection_id: UUID,
474474
query_embeddings: Embeddings,
475+
ids: Optional[IDs] = None,
475476
n_results: int = 10,
476477
where: Optional[Where] = None,
477478
where_document: Optional[WhereDocument] = None,
@@ -480,10 +481,12 @@ def _query(
480481
database: str = DEFAULT_DATABASE,
481482
) -> QueryResult:
482483
query_amount = len(query_embeddings)
484+
filtered_ids_amount = len(ids) if ids else 0
483485
self.product_telemetry_client.capture(
484486
CollectionQueryEvent(
485487
collection_uuid=str(collection_id),
486488
query_amount=query_amount,
489+
filtered_ids_amount=filtered_ids_amount,
487490
n_results=n_results,
488491
with_metadata_filter=query_amount if where is not None else 0,
489492
with_document_filter=query_amount if where_document is not None else 0,
@@ -496,6 +499,7 @@ def _query(
496499

497500
rust_response = self.bindings.query(
498501
str(collection_id),
502+
ids,
499503
query_embeddings,
500504
n_results,
501505
json.dumps(where) if where else None,

chromadb/api/segment.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,7 @@ def _query(
775775
self,
776776
collection_id: UUID,
777777
query_embeddings: Embeddings,
778+
ids: Optional[IDs] = None,
778779
n_results: int = 10,
779780
where: Optional[Where] = None,
780781
where_document: Optional[WhereDocument] = None,
@@ -791,10 +792,12 @@ def _query(
791792
)
792793

793794
query_amount = len(query_embeddings)
795+
ids_amount = len(ids) if ids else 0
794796
self._product_telemetry_client.capture(
795797
CollectionQueryEvent(
796798
collection_uuid=str(collection_id),
797799
query_amount=query_amount,
800+
filtered_ids_amount=ids_amount,
798801
n_results=n_results,
799802
with_metadata_filter=query_amount if where is not None else 0,
800803
with_document_filter=query_amount if where_document is not None else 0,

chromadb/telemetry/product/events.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class CollectionQueryEvent(ProductTelemetryEvent):
137137
batch_size: int
138138
collection_uuid: str
139139
query_amount: int
140+
filtered_ids_amount: int
140141
with_metadata_filter: int
141142
with_document_filter: int
142143
n_results: int
@@ -149,6 +150,7 @@ def __init__(
149150
self,
150151
collection_uuid: str,
151152
query_amount: int,
153+
filtered_ids_amount: int,
152154
with_metadata_filter: int,
153155
with_document_filter: int,
154156
n_results: int,
@@ -161,6 +163,7 @@ def __init__(
161163
super().__init__()
162164
self.collection_uuid = collection_uuid
163165
self.query_amount = query_amount
166+
self.filtered_ids_amount = filtered_ids_amount
164167
self.with_metadata_filter = with_metadata_filter
165168
self.with_document_filter = with_document_filter
166169
self.n_results = n_results
@@ -182,6 +185,7 @@ def batch(self, other: "ProductTelemetryEvent") -> "CollectionQueryEvent":
182185
return CollectionQueryEvent(
183186
collection_uuid=self.collection_uuid,
184187
query_amount=total_amount,
188+
filtered_ids_amount=self.filtered_ids_amount + other.filtered_ids_amount,
185189
with_metadata_filter=self.with_metadata_filter + other.with_metadata_filter,
186190
with_document_filter=self.with_document_filter + other.with_document_filter,
187191
n_results=self.n_results + other.n_results,

chromadb/test/property/test_filtering.py

Lines changed: 111 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import chromadb.test.property.strategies as strategies
2020
import hypothesis.strategies as st
2121
import logging
22-
import random
2322
import re
2423
from chromadb.test.utils.wait_for_version_increase import wait_for_version_increase
2524
import numpy as np
@@ -325,6 +324,7 @@ def test_filterable_metadata_get_limit_offset(
325324
min_size=1,
326325
),
327326
should_compact=st.booleans(),
327+
data=st.data(),
328328
)
329329
def test_filterable_metadata_query(
330330
caplog: pytest.LogCaptureFixture,
@@ -333,6 +333,7 @@ def test_filterable_metadata_query(
333333
record_set: strategies.RecordSet,
334334
filters: List[strategies.Filter],
335335
should_compact: bool,
336+
data: st.DataObject,
336337
) -> None:
337338
caplog.set_level(logging.ERROR)
338339

@@ -355,19 +356,21 @@ def test_filterable_metadata_query(
355356
wait_for_version_increase(client, collection.name, initial_version) # type: ignore
356357

357358
total_count = len(normalized_record_set["ids"])
358-
# Pick a random vector
359+
# Pick a random vector using Hypothesis data
359360
random_query: Embedding
361+
362+
query_index = data.draw(st.integers(min_value=0, max_value=total_count - 1))
360363
if collection.has_embeddings:
361364
assert normalized_record_set["embeddings"] is not None
362365
assert all(isinstance(e, list) for e in normalized_record_set["embeddings"])
363-
random_query = normalized_record_set["embeddings"][
364-
random.randint(0, total_count - 1)
365-
]
366+
# Use data.draw to select index
367+
random_query = normalized_record_set["embeddings"][query_index]
366368
else:
367369
assert isinstance(normalized_record_set["documents"], list)
368370
assert collection.embedding_function is not None
371+
# Use data.draw to select index
369372
random_query = collection.embedding_function(
370-
[normalized_record_set["documents"][random.randint(0, total_count - 1)]]
373+
[normalized_record_set["documents"][query_index]]
371374
)[0]
372375
for filter in filters:
373376
result_ids = set(
@@ -402,7 +405,7 @@ def test_empty_filter(client: ClientAPI) -> None:
402405
query_embeddings=test_query_embedding,
403406
where={"q": {"$eq": 4}}, # type: ignore[dict-item]
404407
n_results=3,
405-
include=["embeddings", "distances", "metadatas"], # type: ignore[list-item]
408+
include=["embeddings", "distances", "metadatas"],
406409
)
407410
assert res["ids"] == [[]]
408411
if res["embeddings"] is not None:
@@ -459,9 +462,108 @@ def check_empty_res(res: GetResult) -> None:
459462

460463
coll.add(ids=test_ids, embeddings=test_embeddings, metadatas=test_metadatas)
461464

462-
res = coll.get(ids=["nope"], include=["embeddings", "metadatas", "documents"]) # type: ignore[list-item]
465+
res = coll.get(ids=["nope"], include=["embeddings", "metadatas", "documents"])
463466
check_empty_res(res)
464467
res = coll.get(
465-
include=["embeddings", "metadatas", "documents"], where={"test": 100} # type: ignore[list-item]
468+
include=["embeddings", "metadatas", "documents"], where={"test": 100}
466469
)
467470
check_empty_res(res)
471+
472+
473+
@settings(
474+
deadline=90000,
475+
suppress_health_check=[
476+
HealthCheck.function_scoped_fixture,
477+
HealthCheck.large_base_example,
478+
],
479+
)
480+
@given(
481+
collection=collection_st,
482+
record_set=recordset_st,
483+
n_results_st=st.integers(min_value=1, max_value=100),
484+
should_compact=st.booleans(),
485+
data=st.data(),
486+
)
487+
def test_query_ids_filter_property(
488+
caplog: pytest.LogCaptureFixture,
489+
client: ClientAPI,
490+
collection: strategies.Collection,
491+
record_set: strategies.RecordSet,
492+
n_results_st: int,
493+
should_compact: bool,
494+
data: st.DataObject,
495+
) -> None:
496+
"""Property test for querying with only the ids filter."""
497+
if (
498+
client.get_settings().chroma_api_impl
499+
== "chromadb.api.async_fastapi.AsyncFastAPI"
500+
):
501+
pytest.skip(
502+
"Skipping test for async client due to potential resource/timeout issues"
503+
)
504+
caplog.set_level(logging.ERROR)
505+
reset(client)
506+
coll = client.create_collection(
507+
name=collection.name,
508+
metadata=collection.metadata, # type: ignore
509+
embedding_function=collection.embedding_function,
510+
)
511+
initial_version = coll.get_model()["version"]
512+
normalized_record_set = invariants.wrap_all(record_set)
513+
514+
if len(normalized_record_set["ids"]) == 0:
515+
# Cannot add empty record set
516+
return
517+
518+
coll.add(**record_set) # type: ignore[arg-type]
519+
520+
if not NOT_CLUSTER_ONLY:
521+
if should_compact and len(normalized_record_set["ids"]) > 10:
522+
wait_for_version_increase(client, collection.name, initial_version) # type: ignore
523+
524+
total_count = len(normalized_record_set["ids"])
525+
n_results = min(n_results_st, total_count)
526+
527+
# Generate a random subset of ids to filter on using Hypothesis data
528+
ids_subset_size = data.draw(st.integers(min_value=0, max_value=total_count))
529+
ids_to_query = data.draw(
530+
st.lists(
531+
st.sampled_from(normalized_record_set["ids"]),
532+
min_size=ids_subset_size,
533+
max_size=ids_subset_size,
534+
unique=True,
535+
)
536+
)
537+
538+
# Pick a random query vector using Hypothesis data
539+
random_query: Embedding
540+
query_index = data.draw(st.integers(min_value=0, max_value=total_count - 1))
541+
if collection.has_embeddings:
542+
assert normalized_record_set["embeddings"] is not None
543+
assert all(isinstance(e, list) for e in normalized_record_set["embeddings"])
544+
# Use data.draw to select index
545+
random_query = normalized_record_set["embeddings"][query_index]
546+
else:
547+
assert isinstance(normalized_record_set["documents"], list)
548+
assert collection.embedding_function is not None
549+
# Use data.draw to select index
550+
random_query = collection.embedding_function(
551+
[normalized_record_set["documents"][query_index]]
552+
)[0]
553+
554+
# Perform the query with only the ids filter
555+
result = coll.query(
556+
query_embeddings=[random_query],
557+
ids=ids_to_query,
558+
n_results=n_results,
559+
)
560+
561+
result_ids = set(result["ids"][0])
562+
filter_ids_set = set(ids_to_query)
563+
564+
# The core assertion: all returned IDs must be within the filter set
565+
assert result_ids.issubset(filter_ids_set)
566+
567+
# Also check that the number of results is reasonable
568+
assert len(result_ids) <= n_results
569+
assert len(result_ids) <= len(filter_ids_set)

0 commit comments

Comments
 (0)