diff --git a/chromadb/test/ef/test_ef.py b/chromadb/test/ef/test_ef.py index c1da2b902c5..82ed00bf0dd 100644 --- a/chromadb/test/ef/test_ef.py +++ b/chromadb/test/ef/test_ef.py @@ -26,6 +26,7 @@ def test_get_builtins_holds() -> None: expected_builtins = { "AmazonBedrockEmbeddingFunction", "BasetenEmbeddingFunction", + "CloudflareWorkersAIEmbeddingFunction", "CohereEmbeddingFunction", "VoyageAIEmbeddingFunction", "GoogleGenerativeAiEmbeddingFunction", diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py index b0d7357471e..21233e3e4ae 100644 --- a/chromadb/utils/embedding_functions/__init__.py +++ b/chromadb/utils/embedding_functions/__init__.py @@ -55,6 +55,9 @@ from chromadb.utils.embedding_functions.baseten_embedding_function import ( BasetenEmbeddingFunction, ) +from chromadb.utils.embedding_functions.cloudflare_workers_ai_embedding_function import ( + CloudflareWorkersAIEmbeddingFunction, +) try: from chromadb.is_thin_client import is_thin_client @@ -82,6 +85,7 @@ "AmazonBedrockEmbeddingFunction", "ChromaLangchainEmbeddingFunction", "BasetenEmbeddingFunction", + "CloudflareWorkersAIEmbeddingFunction", "DefaultEmbeddingFunction", } @@ -141,6 +145,7 @@ def validate_config(config: Dict[str, Any]) -> None: "chroma_langchain": ChromaLangchainEmbeddingFunction, "baseten": BasetenEmbeddingFunction, "default": DefaultEmbeddingFunction, + "cloudflare_workers_ai": CloudflareWorkersAIEmbeddingFunction, } @@ -207,6 +212,7 @@ def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction: "CohereEmbeddingFunction", "OpenAIEmbeddingFunction", "BasetenEmbeddingFunction", + "CloudflareWorkersAIEmbeddingFunction", "HuggingFaceEmbeddingFunction", "HuggingFaceEmbeddingServer", "SentenceTransformerEmbeddingFunction", diff --git a/chromadb/utils/embedding_functions/cloudflare_workers_ai_embedding_function.py b/chromadb/utils/embedding_functions/cloudflare_workers_ai_embedding_function.py new file mode 100644 
index 00000000000..de9d93a0784 --- /dev/null +++ b/chromadb/utils/embedding_functions/cloudflare_workers_ai_embedding_function.py @@ -0,0 +1,153 @@ +from chromadb.api.types import ( + Embeddings, + Documents, + EmbeddingFunction, + Space, +) +from typing import List, Dict, Any, Optional +import os +from chromadb.utils.embedding_functions.schemas import validate_config_schema +from typing import cast + +BASE_URL = "https://api.cloudflare.com/client/v4/accounts" +GATEWAY_BASE_URL = "https://gateway.ai.cloudflare.com/v1" + + +class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction[Documents]): + """ + This class is used to get embeddings for a list of texts using the Cloudflare Workers AI API. + It requires an API key, an account ID and a model name. + """ + + def __init__( + self, + model_name: str, + account_id: str, + api_key: Optional[str] = None, + api_key_env_var: str = "CHROMA_CLOUDFLARE_API_KEY", + gateway_id: Optional[str] = None, + ): + """ + Initialize the CloudflareWorkersAIEmbeddingFunction. See the docs for supported models here: + https://developers.cloudflare.com/workers-ai/models/ + + Args: + model_name: The name of the model to use for text embeddings. + account_id: The account ID for the Cloudflare Workers AI API. + api_key: The API key for the Cloudflare Workers AI API. + api_key_env_var: The environment variable name for the Cloudflare Workers AI API key. + gateway_id: Optional Cloudflare AI Gateway ID. When set, requests are sent through the gateway URL instead of the accounts API. + """ + try: + import httpx + except ImportError: + raise ValueError( + "The httpx python package is not installed. 
Please install it with `pip install httpx`" + ) + self.model_name = model_name + self.account_id = account_id + self.api_key_env_var = api_key_env_var + self.api_key = api_key or os.getenv(api_key_env_var) + self.gateway_id = gateway_id + + if not self.api_key: + raise ValueError(f"The {api_key_env_var} environment variable is not set.") + + if self.gateway_id: + self._api_url = f"{GATEWAY_BASE_URL}/{self.account_id}/{self.gateway_id}/workers-ai/{self.model_name}" + else: + self._api_url = f"{BASE_URL}/{self.account_id}/ai/run/{self.model_name}" + + self._session = httpx.Client() + self._session.headers.update( + {"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"} + ) + + def __call__(self, input: Documents) -> Embeddings: + """ + Generate embeddings for the given documents. + + Args: + input: Documents to generate embeddings for. + + Returns: + Embeddings for the documents. + + Raises: + ValueError: If any document is not a string. + RuntimeError: If the response does not contain a result with embedding data. + """ + if not all(isinstance(item, str) for item in input): + raise ValueError( + "Cloudflare Workers AI only supports text documents, not images" + ) + + payload: Dict[str, Any] = { + "text": input, + } + + resp = self._session.post(self._api_url, json=payload).json() + + # A failed call may omit "result" or set it to null, so check its shape before indexing into it. + if not isinstance(resp.get("result"), dict) or "data" not in resp["result"]: + raise RuntimeError(resp.get("detail", "Unknown error")) + + return cast(Embeddings, resp["result"]["data"]) + + @staticmethod + def name() -> str: + return "cloudflare_workers_ai" + + def default_space(self) -> Space: + return "cosine" + + def supported_spaces(self) -> List[Space]: + return ["cosine", "l2", "ip"] + + @staticmethod + def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": + api_key_env_var = config.get("api_key_env_var") + model_name = config.get("model_name") + account_id = config.get("account_id") + gateway_id = config.get("gateway_id", None) + if api_key_env_var is None or model_name is None or account_id is None: + raise ValueError("api_key_env_var, model_name and account_id are required in the config") + + return 
CloudflareWorkersAIEmbeddingFunction( + api_key_env_var=api_key_env_var, + model_name=model_name, + account_id=account_id, + gateway_id=gateway_id, + ) + + def get_config(self) -> Dict[str, Any]: + config: Dict[str, Any] = { + "api_key_env_var": self.api_key_env_var, + "model_name": self.model_name, + "account_id": self.account_id, + } + # The schema types gateway_id as a string, so leave the key out entirely when no gateway is configured. + if self.gateway_id is not None: + config["gateway_id"] = self.gateway_id + return config + + def validate_config_update( + self, old_config: Dict[str, Any], new_config: Dict[str, Any] + ) -> None: + if "model_name" in new_config: + raise ValueError( + "The model name cannot be changed after the embedding function has been initialized." + ) + + @staticmethod + def validate_config(config: Dict[str, Any]) -> None: + """ + Validate the configuration using the JSON schema. + + Args: + config: Configuration to validate + + Raises: + ValidationError: If the configuration does not match the schema + """ + validate_config_schema(config, "cloudflare_workers_ai") diff --git a/clients/js/packages/chromadb-core/src/embeddings/CloudflareWorkersAIEmbeddingFunction.ts b/clients/js/packages/chromadb-core/src/embeddings/CloudflareWorkersAIEmbeddingFunction.ts new file mode 100644 index 00000000000..75812d6ba44 --- /dev/null +++ b/clients/js/packages/chromadb-core/src/embeddings/CloudflareWorkersAIEmbeddingFunction.ts @@ -0,0 +1,115 @@ +import { IEmbeddingFunction } from "./IEmbeddingFunction"; +import { validateConfigSchema } from "../schemas/schemaUtils"; + +type StoredConfig = { + account_id: string; + model_name: string; + api_key_env_var: string; + gateway_id?: string; +}; + +const BASE_URL = "https://api.cloudflare.com/client/v4/accounts"; +const GATEWAY_BASE_URL = "https://gateway.ai.cloudflare.com/v1"; + +export class CloudflareWorkersAIEmbeddingFunction + implements IEmbeddingFunction +{ + name = "cloudflare_workers_ai"; + + private account_id: string; + private model_name: string; + private api_key_env_var: string; + private gateway_id?: string; + private api_url: string; + private headers: { [key: string]: string 
}; + + constructor({ + cloudflare_api_key, + model_name, + account_id, + api_key_env_var = "CHROMA_CLOUDFLARE_API_KEY", + gateway_id = undefined, + }: { + cloudflare_api_key?: string; + model_name: string; + account_id: string; + api_key_env_var?: string; + gateway_id?: string; + }) { + const apiKey = cloudflare_api_key ?? process.env[api_key_env_var]; + if (!apiKey) { + throw new Error( + `Cloudflare API key is required. Please provide it in the constructor or set the environment variable ${api_key_env_var}.`, + ); + } + + this.model_name = model_name; + this.account_id = account_id; + this.api_key_env_var = api_key_env_var; + this.gateway_id = gateway_id; + + if (this.gateway_id) { + this.api_url = `${GATEWAY_BASE_URL}/${this.account_id}/${this.gateway_id}/workers-ai/${this.model_name}`; + } else { + this.api_url = `${BASE_URL}/${this.account_id}/ai/run/${this.model_name}`; + } + + this.headers = { + Authorization: `Bearer ${apiKey}`, + "Accept-Encoding": "identity", + "Content-Type": "application/json", + }; + } + + public async generate(texts: string[]) { + try { + const payload = { + text: texts, + }; + + const response = await fetch(this.api_url, { + method: "POST", + headers: this.headers, + body: JSON.stringify(payload), + }); + + const resp = await response.json(); + + if (!resp.result || !resp.result.data) { + throw new Error(resp.detail || "Unknown error"); + } + + return resp.result.data; + } catch (error) { + if (error instanceof Error) { + throw new Error( + `Error calling Cloudflare Workers AI API: ${error.message}`, + ); + } else { + throw new Error(`Error calling Cloudflare Workers AI API: ${error}`); + } + } + } + + buildFromConfig(config: StoredConfig): CloudflareWorkersAIEmbeddingFunction { + return new CloudflareWorkersAIEmbeddingFunction({ + model_name: config.model_name, + account_id: config.account_id, + api_key_env_var: config.api_key_env_var, + gateway_id: config.gateway_id ?? 
undefined, + }); + } + + getConfig(): StoredConfig { + return { + model_name: this.model_name, + account_id: this.account_id, + api_key_env_var: this.api_key_env_var, + gateway_id: this.gateway_id ?? undefined, + }; + } + + validateConfig(config: StoredConfig): void { + validateConfigSchema(config, "cloudflare_workers_ai"); + } +} diff --git a/clients/js/packages/chromadb-core/src/schemas/schemaUtils.ts b/clients/js/packages/chromadb-core/src/schemas/schemaUtils.ts index 8fbdd14bcf3..6d0d98b87bd 100644 --- a/clients/js/packages/chromadb-core/src/schemas/schemaUtils.ts +++ b/clients/js/packages/chromadb-core/src/schemas/schemaUtils.ts @@ -20,6 +20,7 @@ import sentenceTransformerSchema from "../../../../../../schemas/embedding_funct import text2vecSchema from "../../../../../../schemas/embedding_functions/text2vec.json"; import transformersSchema from "../../../../../../schemas/embedding_functions/transformers.json"; import voyageaiSchema from "../../../../../../schemas/embedding_functions/voyageai.json"; +import cloudflareWorkersAiSchema from "../../../../../../schemas/embedding_functions/cloudflare_workers_ai.json"; import Ajv from "ajv"; @@ -64,6 +65,7 @@ const schemaMap = { text2vec: text2vecSchema as Schema, transformers: transformersSchema as Schema, voyageai: voyageaiSchema as Schema, + cloudflare_workers_ai: cloudflareWorkersAiSchema as Schema, }; /** diff --git a/docs/docs.trychroma.com/markdoc/content/docs/embeddings/embedding-functions.md b/docs/docs.trychroma.com/markdoc/content/docs/embeddings/embedding-functions.md index 365c880f0bc..19a762546e4 100644 --- a/docs/docs.trychroma.com/markdoc/content/docs/embeddings/embedding-functions.md +++ b/docs/docs.trychroma.com/markdoc/content/docs/embeddings/embedding-functions.md @@ -13,6 +13,7 @@ Chroma provides lightweight wrappers around popular embedding providers, making | [Instructor](../../integrations/embedding-models/instructor) | ✓ | - | | [Hugging Face Embedding 
Server](../../integrations/embedding-models/hugging-face-server) | ✓ | ✓ | | [Jina AI](../../integrations/embedding-models/jina-ai) | ✓ | ✓ | +| [Cloudflare Workers AI](../../integrations/embedding-models/cloudflare-workers-ai) | ✓ | ✓ | We welcome pull requests to add new Embedding Functions to the community. diff --git a/docs/docs.trychroma.com/markdoc/content/integrations/chroma-integrations.md b/docs/docs.trychroma.com/markdoc/content/integrations/chroma-integrations.md index d7a98f040d2..7378e4a3236 100644 --- a/docs/docs.trychroma.com/markdoc/content/integrations/chroma-integrations.md +++ b/docs/docs.trychroma.com/markdoc/content/integrations/chroma-integrations.md @@ -21,6 +21,7 @@ Chroma provides lightweight wrappers around popular embedding providers, making | [Jina AI](./embedding-models/jina-ai) | ✓ | ✓ | | [Roboflow](./embedding-models/roboflow) | ✓ | - | | [Ollama Embeddings](./embedding-models/ollama) | ✓ | ✓ | +| [Cloudflare Workers AI](./embedding-models/cloudflare-workers-ai) | ✓ | ✓ | --- diff --git a/docs/docs.trychroma.com/markdoc/content/integrations/embedding-models/cloudflare-workers-ai.md b/docs/docs.trychroma.com/markdoc/content/integrations/embedding-models/cloudflare-workers-ai.md new file mode 100644 index 00000000000..1cb65dc4f9e --- /dev/null +++ b/docs/docs.trychroma.com/markdoc/content/integrations/embedding-models/cloudflare-workers-ai.md @@ -0,0 +1,51 @@ +--- +id: cloudflare-workers-ai +name: Cloudflare Workers AI +--- + +# Cloudflare Workers AI + +Chroma provides a wrapper around Cloudflare Workers AI embedding models. This embedding function runs remotely against the Cloudflare Workers AI servers, and will require an API key and a Cloudflare account. You can find more information in the [Cloudflare Workers AI Docs](https://developers.cloudflare.com/workers-ai/). + +You can also optionally use the Cloudflare AI Gateway for a more customized solution by setting a `gateway_id` argument. 
See the [Cloudflare AI Gateway Docs](https://developers.cloudflare.com/ai-gateway/providers/workersai/) for more info. + +{% TabbedCodeBlock %} + +{% Tab label="python" %} + +```python +import os +from chromadb.utils.embedding_functions import CloudflareWorkersAIEmbeddingFunction + +os.environ["CHROMA_CLOUDFLARE_API_KEY"] = "" + +ef = CloudflareWorkersAIEmbeddingFunction( + account_id="bd4502421ad9c8e8931d02a616e6845a", + model_name="@cf/baai/bge-m3", + ) +ef(input=["This is my first text to embed", "This is my second document"]) +``` + +{% /Tab %} + +{% Tab label="typescript" %} + +```typescript +import { CloudflareWorkersAIEmbeddingFunction } from 'chromadb'; + +process.env.CHROMA_CLOUDFLARE_API_KEY = "" + +const embedder = new CloudflareWorkersAIEmbeddingFunction({ + account_id: "bd4502421ad9c8e8931d02a616e6845a", + model_name: "@cf/baai/bge-m3", +}); + +// use directly +embedder.generate(['This is my first text to embed', 'This is my second document']); +``` + +{% /Tab %} + +{% /TabbedCodeBlock %} + +You must pass in an `account_id` and `model_name` to the embedding function. It is recommended to set the `CHROMA_CLOUDFLARE_API_KEY` environment variable for the API key, but the embedding function also optionally takes in an `api_key` argument (`cloudflare_api_key` in TypeScript). 
diff --git a/schemas/embedding_functions/cloudflare_workers_ai.json b/schemas/embedding_functions/cloudflare_workers_ai.json new file mode 100644 index 00000000000..90530532265 --- /dev/null +++ b/schemas/embedding_functions/cloudflare_workers_ai.json @@ -0,0 +1,31 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Cloudflare Workers AI Embedding Function Schema", + "description": "Schema for the Cloudflare Workers AI embedding function configuration", + "version": "1.0.0", + "type": "object", + "properties": { + "model_name": { + "type": "string", + "description": "The name of the model to use for text embeddings" + }, + "account_id": { + "type": "string", + "description": "The account ID for the Cloudflare Workers AI API" + }, + "api_key_env_var": { + "type": "string", + "description": "The environment variable name that contains your API key for the Cloudflare Workers AI API" + }, + "gateway_id": { + "type": "string", + "description": "The ID of the Cloudflare AI Gateway to use for a more customized solution" + } + }, + "required": [ + "api_key_env_var", + "model_name", + "account_id" + ], + "additionalProperties": false +}