Skip to content

[ENH] Add Cloudflare Worker AI Embedding Function #4389

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions chromadb/test/ef/test_ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def test_get_builtins_holds() -> None:
expected_builtins = {
"AmazonBedrockEmbeddingFunction",
"BasetenEmbeddingFunction",
"CloudflareWorkersAIEmbeddingFunction",
"CohereEmbeddingFunction",
"VoyageAIEmbeddingFunction",
"GoogleGenerativeAiEmbeddingFunction",
Expand Down
6 changes: 6 additions & 0 deletions chromadb/utils/embedding_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@
from chromadb.utils.embedding_functions.baseten_embedding_function import (
BasetenEmbeddingFunction,
)
from chromadb.utils.embedding_functions.cloudflare_workers_ai_embedding_function import (
CloudflareWorkersAIEmbeddingFunction,
)

try:
from chromadb.is_thin_client import is_thin_client
Expand Down Expand Up @@ -82,6 +85,7 @@
"AmazonBedrockEmbeddingFunction",
"ChromaLangchainEmbeddingFunction",
"BasetenEmbeddingFunction",
"CloudflareWorkersAIEmbeddingFunction",
"DefaultEmbeddingFunction",
}

Expand Down Expand Up @@ -141,6 +145,7 @@ def validate_config(config: Dict[str, Any]) -> None:
"chroma_langchain": ChromaLangchainEmbeddingFunction,
"baseten": BasetenEmbeddingFunction,
"default": DefaultEmbeddingFunction,
"cloudflare_workers_ai": CloudflareWorkersAIEmbeddingFunction,
}


Expand Down Expand Up @@ -207,6 +212,7 @@ def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction:
"CohereEmbeddingFunction",
"OpenAIEmbeddingFunction",
"BasetenEmbeddingFunction",
"CloudflareWorkersAIEmbeddingFunction",
"HuggingFaceEmbeddingFunction",
"HuggingFaceEmbeddingServer",
"SentenceTransformerEmbeddingFunction",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from chromadb.api.types import (
Embeddings,
Documents,
EmbeddingFunction,
Space,
)
from typing import List, Dict, Any, Optional
import os
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from typing import cast

# Root of the Cloudflare REST API. The embedding function appends
# "/<account_id>/ai/run/<model_name>" to it when no gateway is configured.
BASE_URL = "https://api.cloudflare.com/client/v4/accounts"
# Root of the Cloudflare AI Gateway. Used instead of BASE_URL when a
# gateway_id is supplied ("/<account_id>/<gateway_id>/workers-ai/<model_name>").
GATEWAY_BASE_URL = "https://gateway.ai.cloudflare.com/v1"


class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    Embedding function that gets text embeddings from the Cloudflare Workers AI API.

    It requires a model name, a Cloudflare account ID and an API key. Requests
    are sent either straight to the Workers AI endpoint or, when a
    ``gateway_id`` is given, through a Cloudflare AI Gateway.
    """

    def __init__(
        self,
        model_name: str,
        account_id: str,
        api_key: Optional[str] = None,
        api_key_env_var: str = "CHROMA_CLOUDFLARE_API_KEY",
        gateway_id: Optional[str] = None,
    ):
        """
        Initialize the CloudflareWorkersAIEmbeddingFunction. See the docs for supported models here:
        https://developers.cloudflare.com/workers-ai/models/

        Args:
            model_name: The name of the model to use for text embeddings.
            account_id: The account ID for the Cloudflare Workers AI API.
            api_key: The API key for the Cloudflare Workers AI API. When omitted,
                the key is read from the environment variable named by
                ``api_key_env_var``.
            api_key_env_var: The environment variable name for the Cloudflare Workers AI API key.
            gateway_id: Optional Cloudflare AI Gateway ID. When set, requests are
                sent to the gateway URL instead of the direct Workers AI URL.

        Raises:
            ValueError: If ``httpx`` is not installed, or if no API key was
                passed and the environment variable is empty or unset.
        """
        try:
            import httpx
        except ImportError:
            raise ValueError(
                "The httpx python package is not installed. Please install it with `pip install httpx`"
            )
        self.model_name = model_name
        self.account_id = account_id
        self.api_key_env_var = api_key_env_var
        # An explicitly passed key wins over the environment variable.
        self.api_key = api_key or os.getenv(api_key_env_var)
        self.gateway_id = gateway_id

        if not self.api_key:
            raise ValueError(f"The {api_key_env_var} environment variable is not set.")

        if self.gateway_id:
            self._api_url = f"{GATEWAY_BASE_URL}/{self.account_id}/{self.gateway_id}/workers-ai/{self.model_name}"
        else:
            self._api_url = f"{BASE_URL}/{self.account_id}/ai/run/{self.model_name}"

        # One client is reused for every call so the auth header is set once.
        self._session = httpx.Client()
        self._session.headers.update(
            {"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
        )

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for the given documents.

        Args:
            input: Documents to generate embeddings for.

        Returns:
            Embeddings for the documents.

        Raises:
            ValueError: If any item in ``input`` is not a string.
            RuntimeError: If the response does not contain ``result.data``.
        """
        if not all(isinstance(item, str) for item in input):
            raise ValueError(
                "Cloudflare Workers AI only supports text documents, not images"
            )

        payload: Dict[str, Any] = {
            "text": input,
        }

        resp = self._session.post(self._api_url, json=payload).json()

        # The response is only usable when it holds a "result" object with a
        # "data" field. A missing key, or a null / non-object "result", must
        # surface as RuntimeError rather than as a KeyError / TypeError from
        # the membership test itself.
        result = resp.get("result") if isinstance(resp, dict) else None
        if not isinstance(result, dict) or "data" not in result:
            detail = (
                resp.get("detail", "Unknown error")
                if isinstance(resp, dict)
                else "Unknown error"
            )
            raise RuntimeError(detail)

        return cast(Embeddings, result["data"])

    @staticmethod
    def name() -> str:
        """Return the identifier this embedding function is registered under."""
        return "cloudflare_workers_ai"

    def default_space(self) -> Space:
        """Return the distance space used by default ("cosine")."""
        return "cosine"

    def supported_spaces(self) -> List[Space]:
        """Return every distance space this embedding function reports as supported."""
        return ["cosine", "l2", "ip"]

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
        """
        Rebuild the embedding function from a stored configuration.

        Args:
            config: Mapping with ``api_key_env_var``, ``model_name`` and
                ``account_id``, plus an optional ``gateway_id``.

        Returns:
            A new CloudflareWorkersAIEmbeddingFunction. The API key is not part
            of the config; it is read from the named environment variable.
        """
        api_key_env_var = config.get("api_key_env_var")
        model_name = config.get("model_name")
        account_id = config.get("account_id")
        gateway_id = config.get("gateway_id", None)
        if api_key_env_var is None or model_name is None or account_id is None:
            assert False, "This code should not be reached"

        return CloudflareWorkersAIEmbeddingFunction(
            api_key_env_var=api_key_env_var,
            model_name=model_name,
            account_id=account_id,
            gateway_id=gateway_id,
        )

    def get_config(self) -> Dict[str, Any]:
        """Return the serializable configuration; the API key itself is never included."""
        return {
            "api_key_env_var": self.api_key_env_var,
            "model_name": self.model_name,
            "account_id": self.account_id,
            "gateway_id": self.gateway_id,
        }

    def validate_config_update(
        self, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> None:
        """
        Reject configuration updates that try to change the model.

        Raises:
            ValueError: If ``new_config`` contains a ``model_name`` key.
        """
        if "model_name" in new_config:
            raise ValueError(
                "The model name cannot be changed after the embedding function has been initialized."
            )

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> None:
        """
        Validate the configuration using the JSON schema.

        Args:
            config: Configuration to validate

        Raises:
            ValidationError: If the configuration does not match the schema
        """
        validate_config_schema(config, "cloudflare_workers_ai")
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import { IEmbeddingFunction } from "./IEmbeddingFunction";
import { validateConfigSchema } from "../schemas/schemaUtils";

// Shape of the configuration returned by getConfig() and accepted by
// buildFromConfig(). The API key itself is not part of it; only the name of
// the environment variable it is read from.
type StoredConfig = {
account_id: string;
model_name: string;
api_key_env_var: string;
// When present, requests are sent through the AI Gateway URL.
gateway_id?: string;
};

// Root of the Cloudflare REST API, used when no gateway_id is configured.
const BASE_URL = "https://api.cloudflare.com/client/v4/accounts";
// Root of the Cloudflare AI Gateway, used when a gateway_id is configured.
const GATEWAY_BASE_URL = "https://gateway.ai.cloudflare.com/v1";

/**
 * Embedding function that gets text embeddings from the Cloudflare Workers AI
 * API, either directly or through a Cloudflare AI Gateway when a `gateway_id`
 * is given.
 */
export class CloudflareWorkersAIEmbeddingFunction
  implements IEmbeddingFunction
{
  name = "cloudflare_workers_ai";

  private account_id: string;
  private model_name: string;
  private api_key_env_var: string;
  private gateway_id?: string;
  private api_url: string;
  private headers: { [key: string]: string };

  /**
   * @param cloudflare_api_key - API key; when omitted it is read from the
   *   environment variable named by `api_key_env_var`.
   * @param model_name - Name of the embedding model to run.
   * @param account_id - Cloudflare account ID.
   * @param api_key_env_var - Environment variable holding the API key.
   *   Defaults to `CHROMA_CLOUDFLARE_API_KEY`.
   * @param gateway_id - Optional AI Gateway ID; when set, requests go to the
   *   gateway URL instead of the direct Workers AI URL.
   * @throws Error if no API key is passed and the environment variable is
   *   empty or unset.
   */
  constructor({
    cloudflare_api_key,
    model_name,
    account_id,
    api_key_env_var = "CHROMA_CLOUDFLARE_API_KEY",
    gateway_id = undefined,
  }: {
    cloudflare_api_key?: string;
    model_name: string;
    account_id: string;
    // Optional: the destructuring default above supplies the value, so
    // callers must be allowed to leave it out.
    api_key_env_var?: string;
    gateway_id?: string;
  }) {
    const apiKey = cloudflare_api_key ?? process.env[api_key_env_var];
    if (!apiKey) {
      throw new Error(
        `Cloudflare API key is required. Please provide it in the constructor or set the environment variable ${api_key_env_var}.`,
      );
    }

    this.model_name = model_name;
    this.account_id = account_id;
    this.api_key_env_var = api_key_env_var;
    this.gateway_id = gateway_id;

    if (this.gateway_id) {
      this.api_url = `${GATEWAY_BASE_URL}/${this.account_id}/${this.gateway_id}/workers-ai/${this.model_name}`;
    } else {
      this.api_url = `${BASE_URL}/${this.account_id}/ai/run/${this.model_name}`;
    }

    this.headers = {
      Authorization: `Bearer ${apiKey}`,
      "Accept-Encoding": "identity",
      "Content-Type": "application/json",
    };
  }

  /**
   * Request embeddings for the given texts.
   *
   * @param texts - Texts to embed.
   * @returns The `result.data` field of the API response.
   * @throws Error (prefixed with "Error calling Cloudflare Workers AI API")
   *   if the request fails or the response has no `result.data`.
   */
  public async generate(texts: string[]) {
    try {
      const payload = {
        text: texts,
      };

      const response = await fetch(this.api_url, {
        method: "POST",
        headers: this.headers,
        body: JSON.stringify(payload),
      });

      const resp = await response.json();

      if (!resp.result || !resp.result.data) {
        throw new Error(resp.detail || "Unknown error");
      }

      return resp.result.data;
    } catch (error) {
      // Every failure, including the one thrown just above, is rewrapped
      // with a common prefix.
      if (error instanceof Error) {
        throw new Error(
          `Error calling Cloudflare Workers AI API: ${error.message}`,
        );
      } else {
        throw new Error(`Error calling Cloudflare Workers AI API: ${error}`);
      }
    }
  }

  /**
   * Build a new instance from a stored configuration. The API key is not in
   * the config; the constructor reads it from the named environment variable.
   */
  buildFromConfig(config: StoredConfig): CloudflareWorkersAIEmbeddingFunction {
    return new CloudflareWorkersAIEmbeddingFunction({
      model_name: config.model_name,
      account_id: config.account_id,
      api_key_env_var: config.api_key_env_var,
      gateway_id: config.gateway_id ?? undefined,
    });
  }

  /** Return the serializable configuration; the API key is never included. */
  getConfig(): StoredConfig {
    return {
      model_name: this.model_name,
      account_id: this.account_id,
      api_key_env_var: this.api_key_env_var,
      gateway_id: this.gateway_id ?? undefined,
    };
  }

  /** Validate a configuration against the "cloudflare_workers_ai" schema. */
  validateConfig(config: StoredConfig): void {
    validateConfigSchema(config, "cloudflare_workers_ai");
  }
}
2 changes: 2 additions & 0 deletions clients/js/packages/chromadb-core/src/schemas/schemaUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import sentenceTransformerSchema from "../../../../../../schemas/embedding_funct
import text2vecSchema from "../../../../../../schemas/embedding_functions/text2vec.json";
import transformersSchema from "../../../../../../schemas/embedding_functions/transformers.json";
import voyageaiSchema from "../../../../../../schemas/embedding_functions/voyageai.json";
import cloudflareWorkersAiSchema from "../../../../../../schemas/embedding_functions/cloudflare_workers_ai.json";

import Ajv from "ajv";

Expand Down Expand Up @@ -64,6 +65,7 @@ const schemaMap = {
text2vec: text2vecSchema as Schema,
transformers: transformersSchema as Schema,
voyageai: voyageaiSchema as Schema,
cloudflare_workers_ai: cloudflareWorkersAiSchema as Schema,
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Chroma provides lightweight wrappers around popular embedding providers, making
| [Instructor](../../integrations/embedding-models/instructor) | ✓ | - |
| [Hugging Face Embedding Server](../../integrations/embedding-models/hugging-face-server) | ✓ | ✓ |
| [Jina AI](../../integrations/embedding-models/jina-ai) | ✓ | ✓ |
| [Cloudflare Workers AI](../../integrations/embedding-models/cloudflare-workers-ai) | ✓ | ✓ |

We welcome pull requests to add new Embedding Functions to the community.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Chroma provides lightweight wrappers around popular embedding providers, making
| [Jina AI](./embedding-models/jina-ai) | ✓ | ✓ |
| [Roboflow](./embedding-models/roboflow) | ✓ | - |
| [Ollama Embeddings](./embedding-models/ollama) | ✓ | ✓ |
| [Cloudflare Workers AI](./embedding-models/cloudflare-workers-ai) | ✓ | ✓ |

---

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
---
id: cloudflare-workers-ai
name: Cloudflare Workers AI
---

# Cloudflare Workers AI

Chroma provides a wrapper around Cloudflare Workers AI embedding models. This embedding function runs remotely against the Cloudflare Workers AI servers, and will require an API key and a Cloudflare account. You can find more information in the [Cloudflare Workers AI Docs](https://developers.cloudflare.com/workers-ai/).

You can also optionally use the Cloudflare AI Gateway for a more customized solution by setting a `gateway_id` argument. See the [Cloudflare AI Gateway Docs](https://developers.cloudflare.com/ai-gateway/providers/workersai/) for more info.

{% TabbedCodeBlock %}

{% Tab label="python" %}

```python
import os

from chromadb.utils.embedding_functions import CloudflareWorkersAIEmbeddingFunction

os.environ["CHROMA_CLOUDFLARE_API_KEY"] = "<INSERT API KEY HERE>"

ef = CloudflareWorkersAIEmbeddingFunction(
account_id="bd4502421ad9c8e8931d02a616e6845a",
model_name="@cf/baai/bge-m3",
)
ef(input=["This is my first text to embed", "This is my second document"])
```

{% /Tab %}

{% Tab label="typescript" %}

```typescript
import { CloudflareWorkersAIEmbeddingFunction } from 'chromadb';

process.env.CHROMA_CLOUDFLARE_API_KEY = "<INSERT API KEY HERE>"

const embedder = new CloudflareWorkersAIEmbeddingFunction({
account_id: "bd4502421ad9c8e8931d02a616e6845a",
model_name: "@cf/baai/bge-m3",
});

// use directly
embedder.generate(['This is my first text to embed', 'This is my second document']);
```

{% /Tab %}

{% /TabbedCodeBlock %}

You must pass an `account_id` and a `model_name` to the embedding function. We recommend supplying the API key through the `CHROMA_CLOUDFLARE_API_KEY` environment variable, but the embedding function also accepts it directly as an optional `api_key` argument (`cloudflare_api_key` in TypeScript).
Loading
Loading