Skip to content

Commit 7a3f562

Browse files
authored
[CHORE] Add test to validate embedding function error handling (#4259)
## Description of changes This PR adds tests for validation_context to ensure the errors raised match the ones given by the caller, in this case an embedding function. tests for #4235 ## Test plan *How are these changes tested?* - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?*
1 parent 61a009a commit 7a3f562

File tree

3 files changed

+55
-2
lines changed

3 files changed

+55
-2
lines changed

chromadb/api/models/CollectionCommon.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> T:
9494
try:
9595
return func(self, *args, **kwargs)
9696
except Exception as e:
97-
# modify the error message
9897
msg = f"{str(e)} in {name}."
9998
# add the rest of the args to the error message if they exist
10099
e.args = (msg,) + e.args[1:] if e.args else ()

chromadb/test/ef/test_ef.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
from chromadb.utils import embedding_functions
2-
from chromadb.utils.embedding_functions import EmbeddingFunction
2+
from chromadb.utils.embedding_functions import (
3+
EmbeddingFunction,
4+
register_embedding_function,
5+
)
6+
from typing import Dict, Any
7+
import pytest
8+
from chromadb.api.types import (
9+
Embeddings,
10+
Space,
11+
Embeddable,
12+
)
13+
from chromadb.api.models.CollectionCommon import validation_context
314

415

516
def test_get_builtins_holds() -> None:
@@ -54,3 +65,39 @@ def test_ef_imports() -> None:
5465
assert hasattr(embedding_functions, ef)
5566
assert isinstance(getattr(embedding_functions, ef), type)
5667
assert issubclass(getattr(embedding_functions, ef), EmbeddingFunction)
68+
69+
70+
@register_embedding_function
71+
class CustomEmbeddingFunction(EmbeddingFunction[Embeddable]):
72+
def __init__(self, dim: int = 3):
73+
self._dim = dim
74+
75+
@validation_context("custom_ef_call")
76+
def __call__(self, input: Embeddable) -> Embeddings:
77+
raise Exception("This is a test exception")
78+
79+
@staticmethod
80+
def name() -> str:
81+
return "custom_ef"
82+
83+
def get_config(self) -> Dict[str, Any]:
84+
return {"dim": self._dim}
85+
86+
@staticmethod
87+
def build_from_config(config: Dict[str, Any]) -> "CustomEmbeddingFunction":
88+
return CustomEmbeddingFunction(dim=config["dim"])
89+
90+
def default_space(self) -> Space:
91+
return "cosine"
92+
93+
94+
def test_validation_context_with_custom_ef() -> None:
95+
custom_ef = CustomEmbeddingFunction()
96+
97+
with pytest.raises(Exception) as excinfo:
98+
custom_ef(["test data"])
99+
100+
original_msg = "This is a test exception"
101+
expected_msg = f"{original_msg} in custom_ef_call."
102+
assert str(excinfo.value) == expected_msg
103+
assert excinfo.value.args == (expected_msg,)

chromadb/test/ef/test_openai_ef.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,10 @@ def test_with_embedding_dimensions_not_working_with_old_model() -> None:
2929
Exception, match="This model does not support specifying dimensions"
3030
):
3131
ef(["hello world"])
32+
33+
34+
def test_with_incorrect_api_key() -> None:
35+
pytest.importorskip("openai", reason="openai not installed")
36+
ef = OpenAIEmbeddingFunction(api_key="incorrect_api_key", dimensions=64)
37+
with pytest.raises(Exception, match="Incorrect API key provided"):
38+
ef(["hello world"])

0 commit comments

Comments
 (0)