|
1 | 1 | from chromadb.utils import embedding_functions
|
2 |
| -from chromadb.utils.embedding_functions import EmbeddingFunction |
| 2 | +from chromadb.utils.embedding_functions import ( |
| 3 | + EmbeddingFunction, |
| 4 | + register_embedding_function, |
| 5 | +) |
| 6 | +from typing import Dict, Any |
| 7 | +import pytest |
| 8 | +from chromadb.api.types import ( |
| 9 | + Embeddings, |
| 10 | + Space, |
| 11 | + Embeddable, |
| 12 | +) |
| 13 | +from chromadb.api.models.CollectionCommon import validation_context |
3 | 14 |
|
4 | 15 |
|
5 | 16 | def test_get_builtins_holds() -> None:
|
@@ -54,3 +65,39 @@ def test_ef_imports() -> None:
|
54 | 65 | assert hasattr(embedding_functions, ef)
|
55 | 66 | assert isinstance(getattr(embedding_functions, ef), type)
|
56 | 67 | assert issubclass(getattr(embedding_functions, ef), EmbeddingFunction)
|
| 68 | + |
| 69 | + |
| 70 | +@register_embedding_function |
| 71 | +class CustomEmbeddingFunction(EmbeddingFunction[Embeddable]): |
| 72 | + def __init__(self, dim: int = 3): |
| 73 | + self._dim = dim |
| 74 | + |
| 75 | + @validation_context("custom_ef_call") |
| 76 | + def __call__(self, input: Embeddable) -> Embeddings: |
| 77 | + raise Exception("This is a test exception") |
| 78 | + |
| 79 | + @staticmethod |
| 80 | + def name() -> str: |
| 81 | + return "custom_ef" |
| 82 | + |
| 83 | + def get_config(self) -> Dict[str, Any]: |
| 84 | + return {"dim": self._dim} |
| 85 | + |
| 86 | + @staticmethod |
| 87 | + def build_from_config(config: Dict[str, Any]) -> "CustomEmbeddingFunction": |
| 88 | + return CustomEmbeddingFunction(dim=config["dim"]) |
| 89 | + |
| 90 | + def default_space(self) -> Space: |
| 91 | + return "cosine" |
| 92 | + |
| 93 | + |
| 94 | +def test_validation_context_with_custom_ef() -> None: |
| 95 | + custom_ef = CustomEmbeddingFunction() |
| 96 | + |
| 97 | + with pytest.raises(Exception) as excinfo: |
| 98 | + custom_ef(["test data"]) |
| 99 | + |
| 100 | + original_msg = "This is a test exception" |
| 101 | + expected_msg = f"{original_msg} in custom_ef_call." |
| 102 | + assert str(excinfo.value) == expected_msg |
| 103 | + assert excinfo.value.args == (expected_msg,) |
0 commit comments