Skip to content

Commit 5bcd97d

Browse files
committed
[ENH] Update Jina embedding function to support all models and configurations
1 parent 1addfda commit 5bcd97d

File tree

3 files changed: 174 additions and 14 deletions.

chromadb/utils/embedding_functions/jina_embedding_function.py

Lines changed: 67 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -8,14 +8,20 @@
88
class JinaEmbeddingFunction(EmbeddingFunction[Documents]):
99
"""
1010
This class is used to get embeddings for a list of texts using the Jina AI API.
11-
It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en".
11+
It requires an API key and a model name. The default model name is "jina-embeddings-v3".
1212
"""
1313

1414
def __init__(
1515
self,
1616
api_key: Optional[str] = None,
17-
model_name: str = "jina-embeddings-v2-base-en",
17+
model_name: str = "jina-embeddings-v3",
1818
api_key_env_var: str = "CHROMA_JINA_API_KEY",
19+
task: Optional[str] = None,
20+
late_chunking: Optional[bool] = None,
21+
truncate: Optional[bool] = None,
22+
dimensions: Optional[int] = None,
23+
embedding_type: Optional[str] = None,
24+
normalized: Optional[bool] = None,
1925
):
2026
"""
2127
Initialize the JinaEmbeddingFunction.
@@ -24,7 +30,7 @@ def __init__(
2430
api_key_env_var (str, optional): Environment variable name that contains your API key for the Jina AI API.
2531
Defaults to "CHROMA_JINA_API_KEY".
2632
model_name (str, optional): The name of the model to use for text embeddings.
27-
Defaults to "jina-embeddings-v2-base-en".
33+
Defaults to "jina-embeddings-v3".
2834
"""
2935
try:
3036
import httpx
@@ -40,6 +46,14 @@ def __init__(
4046

4147
self.model_name = model_name
4248

49+
# Store the optional request parameters as given; a value of None means the field is omitted from the API payload
50+
self.task = task
51+
self.late_chunking = late_chunking
52+
self.truncate = truncate
53+
self.dimensions = dimensions
54+
self.embedding_type = embedding_type
55+
self.normalized = normalized
56+
4357
self._api_url = "https://api.jina.ai/v1/embeddings"
4458
self._session = httpx.Client()
4559
self._session.headers.update(
@@ -51,7 +65,7 @@ def __call__(self, input: Documents) -> Embeddings:
5165
Get the embeddings for a list of texts.
5266
5367
Args:
54-
input (Documents): A list of texts or images to get embeddings for.
68+
input (Documents): A list of texts to get embeddings for.
5569
5670
Returns:
5771
Embeddings: The embeddings for the texts.
@@ -64,10 +78,31 @@ def __call__(self, input: Documents) -> Embeddings:
6478
if not all(isinstance(item, str) for item in input):
6579
raise ValueError("Jina AI only supports text documents, not images")
6680

81+
payload: Dict[str, Any] = {
82+
"input": input,
83+
"model": self.model_name,
84+
}
85+
86+
if self.task is not None:
87+
payload["task"] = self.task
88+
89+
if self.late_chunking is not None:
90+
payload["late_chunking"] = self.late_chunking
91+
92+
if self.truncate is not None:
93+
payload["truncate"] = self.truncate
94+
95+
if self.dimensions is not None:
96+
payload["dimensions"] = self.dimensions
97+
98+
if self.embedding_type is not None:
99+
payload["embedding_type"] = self.embedding_type
100+
101+
if self.normalized is not None:
102+
payload["normalized"] = self.normalized
103+
67104
# Call Jina AI Embedding API
68-
resp = self._session.post(
69-
self._api_url, json={"input": input, "model": self.model_name}
70-
).json()
105+
resp = self._session.post(self._api_url, json=payload).json()
71106

72107
if "data" not in resp:
73108
raise RuntimeError(resp.get("detail", "Unknown error"))
@@ -97,16 +132,38 @@ def supported_spaces(self) -> List[Space]:
97132
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
98133
api_key_env_var = config.get("api_key_env_var")
99134
model_name = config.get("model_name")
135+
task = config.get("task")
136+
late_chunking = config.get("late_chunking")
137+
truncate = config.get("truncate")
138+
dimensions = config.get("dimensions")
139+
embedding_type = config.get("embedding_type")
140+
normalized = config.get("normalized")
100141

101142
if api_key_env_var is None or model_name is None:
102-
assert False, "This code should not be reached"
143+
assert False, "This code should not be reached" # this is for type checking
103144

104145
return JinaEmbeddingFunction(
105-
api_key_env_var=api_key_env_var, model_name=model_name
146+
api_key_env_var=api_key_env_var,
147+
model_name=model_name,
148+
task=task,
149+
late_chunking=late_chunking,
150+
truncate=truncate,
151+
dimensions=dimensions,
152+
embedding_type=embedding_type,
153+
normalized=normalized,
106154
)
107155

108156
def get_config(self) -> Dict[str, Any]:
109-
return {"api_key_env_var": self.api_key_env_var, "model_name": self.model_name}
157+
return {
158+
"api_key_env_var": self.api_key_env_var,
159+
"model_name": self.model_name,
160+
"task": self.task,
161+
"late_chunking": self.late_chunking,
162+
"truncate": self.truncate,
163+
"dimensions": self.dimensions,
164+
"embedding_type": self.embedding_type,
165+
"normalized": self.normalized,
166+
}
110167

111168
def validate_config_update(
112169
self, old_config: Dict[str, Any], new_config: Dict[str, Any]

clients/js/packages/chromadb-core/src/embeddings/JinaEmbeddingFunction.ts

Lines changed: 83 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,24 +3,59 @@ import { validateConfigSchema } from "../schemas/schemaUtils";
33
type StoredConfig = {
44
api_key_env_var: string;
55
model_name: string;
6+
task?: string;
7+
late_chunking?: boolean;
8+
truncate?: boolean;
9+
dimensions?: number;
10+
embedding_type?: string;
11+
normalized?: boolean;
612
};
713

14+
interface JinaRequestBody {
15+
input: string[];
16+
model: string;
17+
task?: string;
18+
late_chunking?: boolean;
19+
truncate?: boolean;
20+
dimensions?: number;
21+
embedding_type?: string;
22+
normalized?: boolean;
23+
}
24+
825
export class JinaEmbeddingFunction implements IEmbeddingFunction {
926
name = "jina";
1027

1128
private api_key_env_var: string;
1229
private model_name: string;
1330
private api_url: string;
1431
private headers: { [key: string]: string };
32+
private task: string | undefined;
33+
private late_chunking: boolean | undefined;
34+
private truncate: boolean | undefined;
35+
private dimensions: number | undefined;
36+
private embedding_type: string | undefined;
37+
private normalized: boolean | undefined;
1538

1639
constructor({
1740
jinaai_api_key,
1841
model_name = "jina-embeddings-v2-base-en",
1942
api_key_env_var = "JINAAI_API_KEY",
43+
task,
44+
late_chunking,
45+
truncate,
46+
dimensions,
47+
embedding_type,
48+
normalized,
2049
}: {
2150
jinaai_api_key?: string;
2251
model_name?: string;
2352
api_key_env_var: string;
53+
task?: string;
54+
late_chunking?: boolean;
55+
truncate?: boolean;
56+
dimensions?: number;
57+
embedding_type?: string;
58+
normalized?: boolean;
2459
}) {
2560
const apiKey = jinaai_api_key ?? process.env[api_key_env_var];
2661
if (!apiKey) {
@@ -31,6 +66,12 @@ export class JinaEmbeddingFunction implements IEmbeddingFunction {
3166

3267
this.model_name = model_name;
3368
this.api_key_env_var = api_key_env_var;
69+
this.task = task;
70+
this.late_chunking = late_chunking;
71+
this.truncate = truncate;
72+
this.dimensions = dimensions;
73+
this.embedding_type = embedding_type;
74+
this.normalized = normalized;
3475

3576
this.api_url = "https://api.jina.ai/v1/embeddings";
3677
this.headers = {
@@ -41,14 +82,40 @@ export class JinaEmbeddingFunction implements IEmbeddingFunction {
4182
}
4283

4384
public async generate(texts: string[]) {
85+
let json_body: JinaRequestBody = {
86+
input: texts,
87+
model: this.model_name,
88+
};
89+
90+
if (this.task) {
91+
json_body.task = this.task;
92+
}
93+
94+
if (this.late_chunking) {
95+
json_body.late_chunking = this.late_chunking;
96+
}
97+
98+
if (this.truncate) {
99+
json_body.truncate = this.truncate;
100+
}
101+
102+
if (this.dimensions) {
103+
json_body.dimensions = this.dimensions;
104+
}
105+
106+
if (this.embedding_type) {
107+
json_body.embedding_type = this.embedding_type;
108+
}
109+
110+
if (this.normalized) {
111+
json_body.normalized = this.normalized;
112+
}
113+
44114
try {
45115
const response = await fetch(this.api_url, {
46116
method: "POST",
47117
headers: this.headers,
48-
body: JSON.stringify({
49-
input: texts,
50-
model: this.model_name,
51-
}),
118+
body: JSON.stringify(json_body),
52119
});
53120

54121
const data = (await response.json()) as { data: any[]; detail: string };
@@ -73,13 +140,25 @@ export class JinaEmbeddingFunction implements IEmbeddingFunction {
73140
return new JinaEmbeddingFunction({
74141
model_name: config.model_name,
75142
api_key_env_var: config.api_key_env_var,
143+
task: config.task,
144+
late_chunking: config.late_chunking,
145+
truncate: config.truncate,
146+
dimensions: config.dimensions,
147+
embedding_type: config.embedding_type,
148+
normalized: config.normalized,
76149
});
77150
}
78151

79152
getConfig(): StoredConfig {
80153
return {
81154
api_key_env_var: this.api_key_env_var,
82155
model_name: this.model_name,
156+
task: this.task,
157+
late_chunking: this.late_chunking,
158+
truncate: this.truncate,
159+
dimensions: this.dimensions,
160+
embedding_type: this.embedding_type,
161+
normalized: this.normalized,
83162
};
84163
}
85164

schemas/embedding_functions/jina.json

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,30 @@
1212
"api_key_env_var": {
1313
"type": "string",
1414
"description": "Parameter api_key_env_var for the jina embedding function"
15+
},
16+
"task": {
17+
"type": "string",
18+
"description": "Parameter task for the jina embedding function"
19+
},
20+
"late_chunking": {
21+
"type": "boolean",
22+
"description": "Parameter late_chunking for the jina embedding function"
23+
},
24+
"truncate": {
25+
"type": "boolean",
26+
"description": "Parameter truncate for the jina embedding function"
27+
},
28+
"dimensions": {
29+
"type": "integer",
30+
"description": "Parameter dimensions for the jina embedding function"
31+
},
32+
"embedding_type": {
33+
"type": "string",
34+
"description": "Parameter embedding_type for the jina embedding function"
35+
},
36+
"normalized": {
37+
"type": "boolean",
38+
"description": "Parameter normalized for the jina embedding function"
1539
}
1640
},
1741
"required": [

0 commit comments

Comments (0)