8
8
class JinaEmbeddingFunction (EmbeddingFunction [Documents ]):
9
9
"""
10
10
This class is used to get embeddings for a list of texts using the Jina AI API.
11
- It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en ".
11
+ It requires an API key and a model name. The default model name is "jina-embeddings-v3 ".
12
12
"""
13
13
14
14
def __init__ (
15
15
self ,
16
16
api_key : Optional [str ] = None ,
17
- model_name : str = "jina-embeddings-v2-base-en " ,
17
+ model_name : str = "jina-embeddings-v3 " ,
18
18
api_key_env_var : str = "CHROMA_JINA_API_KEY" ,
19
+ task : Optional [str ] = None ,
20
+ late_chunking : Optional [bool ] = None ,
21
+ truncate : Optional [bool ] = None ,
22
+ dimensions : Optional [int ] = None ,
23
+ embedding_type : Optional [str ] = None ,
24
+ normalized : Optional [bool ] = None ,
19
25
):
20
26
"""
21
27
Initialize the JinaEmbeddingFunction.
@@ -24,7 +30,7 @@ def __init__(
24
30
api_key_env_var (str, optional): Environment variable name that contains your API key for the Jina AI API.
25
31
Defaults to "CHROMA_JINA_API_KEY".
26
32
model_name (str, optional): The name of the model to use for text embeddings.
27
- Defaults to "jina-embeddings-v2-base-en ".
33
+ Defaults to "jina-embeddings-v3 ".
28
34
"""
29
35
try :
30
36
import httpx
@@ -40,6 +46,14 @@ def __init__(
40
46
41
47
self .model_name = model_name
42
48
49
+ # Store the optional request parameters as given; a value of None means the field is left out of the request payload
50
+ self .task = task
51
+ self .late_chunking = late_chunking
52
+ self .truncate = truncate
53
+ self .dimensions = dimensions
54
+ self .embedding_type = embedding_type
55
+ self .normalized = normalized
56
+
43
57
self ._api_url = "https://api.jina.ai/v1/embeddings"
44
58
self ._session = httpx .Client ()
45
59
self ._session .headers .update (
@@ -51,7 +65,7 @@ def __call__(self, input: Documents) -> Embeddings:
51
65
Get the embeddings for a list of texts.
52
66
53
67
Args:
54
- input (Documents): A list of texts or images to get embeddings for.
68
+ input (Documents): A list of texts to get embeddings for.
55
69
56
70
Returns:
57
71
Embeddings: The embeddings for the texts.
@@ -64,10 +78,31 @@ def __call__(self, input: Documents) -> Embeddings:
64
78
if not all (isinstance (item , str ) for item in input ):
65
79
raise ValueError ("Jina AI only supports text documents, not images" )
66
80
81
+ payload : Dict [str , Any ] = {
82
+ "input" : input ,
83
+ "model" : self .model_name ,
84
+ }
85
+
86
+ if self .task is not None :
87
+ payload ["task" ] = self .task
88
+
89
+ if self .late_chunking is not None :
90
+ payload ["late_chunking" ] = self .late_chunking
91
+
92
+ if self .truncate is not None :
93
+ payload ["truncate" ] = self .truncate
94
+
95
+ if self .dimensions is not None :
96
+ payload ["dimensions" ] = self .dimensions
97
+
98
+ if self .embedding_type is not None :
99
+ payload ["embedding_type" ] = self .embedding_type
100
+
101
+ if self .normalized is not None :
102
+ payload ["normalized" ] = self .normalized
103
+
67
104
# Call Jina AI Embedding API
68
- resp = self ._session .post (
69
- self ._api_url , json = {"input" : input , "model" : self .model_name }
70
- ).json ()
105
+ resp = self ._session .post (self ._api_url , json = payload ).json ()
71
106
72
107
if "data" not in resp :
73
108
raise RuntimeError (resp .get ("detail" , "Unknown error" ))
@@ -97,16 +132,38 @@ def supported_spaces(self) -> List[Space]:
97
132
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
    """
    Build a JinaEmbeddingFunction from a configuration dictionary.

    Args:
        config (Dict[str, Any]): Configuration values. "api_key_env_var" and
            "model_name" are required. "task", "late_chunking", "truncate",
            "dimensions", "embedding_type" and "normalized" are optional and
            are forwarded as None when absent.

    Returns:
        EmbeddingFunction[Documents]: A JinaEmbeddingFunction constructed
            from the values in config.

    Raises:
        AssertionError: If "api_key_env_var" or "model_name" is missing or None.
    """
    api_key_env_var = config.get("api_key_env_var")
    model_name = config.get("model_name")

    if api_key_env_var is None or model_name is None:
        # Raised explicitly instead of `assert False` so the guard is not
        # stripped under `python -O`; the exception type and message are
        # unchanged. The check also narrows the Optional values for type
        # checking.
        raise AssertionError("This code should not be reached")

    return JinaEmbeddingFunction(
        api_key_env_var=api_key_env_var,
        model_name=model_name,
        task=config.get("task"),
        late_chunking=config.get("late_chunking"),
        truncate=config.get("truncate"),
        dimensions=config.get("dimensions"),
        embedding_type=config.get("embedding_type"),
        normalized=config.get("normalized"),
    )
107
155
108
156
def get_config(self) -> Dict[str, Any]:
    """
    Return the configuration of this embedding function as a dictionary.

    Returns:
        Dict[str, Any]: The instance's api_key_env_var, model_name, task,
            late_chunking, truncate, dimensions, embedding_type and
            normalized attributes, keyed by those names. Attribute values
            are returned as stored, including None.
    """
    return {
        "api_key_env_var": self.api_key_env_var,
        "model_name": self.model_name,
        "task": self.task,
        "late_chunking": self.late_chunking,
        "truncate": self.truncate,
        "dimensions": self.dimensions,
        "embedding_type": self.embedding_type,
        "normalized": self.normalized,
    }
110
167
111
168
def validate_config_update (
112
169
self , old_config : Dict [str , Any ], new_config : Dict [str , Any ]
0 commit comments