@@ -175,39 +175,29 @@ impl ZipfCache {
}
}

- ////////////////////////////////////////////// client //////////////////////////////////////////////
+ ////////////////////////////////////////////// Connection /////////////////////////////////////////////

- /// Instantiate a new Chroma client. This will use the CHROMA_HOST environment variable (or
- /// http://localhost:8000 when unset) as the argument to [client_for_url].
- pub async fn client() -> ChromaClient {
-     let url = std::env::var("CHROMA_HOST").unwrap_or_else(|_| "http://localhost:8000".into());
-     client_for_url(url).await
+ #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
+ pub struct Connection {
+     pub url: String,
+     pub api_key: String,
+     pub database: String,
}

- /// Create a new Chroma client for the given URL. This will use the CHROMA_TOKEN environment
- /// variable if set, or no authentication if unset.
- pub async fn client_for_url(url: String) -> ChromaClient {
-     if let Ok(auth) = std::env::var("CHROMA_TOKEN") {
-         let db = std::env::var("CHROMA_DATABASE").unwrap_or_else(|_| "hf-tiny-stories".into());
-         ChromaClient::new(ChromaClientOptions {
-             url: Some(url),
-             auth: ChromaAuthMethod::TokenAuth {
-                 token: auth,
-                 header: ChromaTokenHeader::XChromaToken,
-             },
-             database: db,
-         })
-         .await
-         .unwrap()
-     } else {
-         ChromaClient::new(ChromaClientOptions {
-             url: Some(url),
-             auth: ChromaAuthMethod::None,
-             database: "default_database".to_string(),
-         })
-         .await
-         .unwrap()
-     }
+ ////////////////////////////////////////////// client //////////////////////////////////////////////
+
+ /// Instantiate a new Chroma client.
+ pub async fn client(connection: Connection) -> ChromaClient {
+     ChromaClient::new(ChromaClientOptions {
+         url: Some(connection.url.clone()),
+         auth: ChromaAuthMethod::TokenAuth {
+             token: connection.api_key.clone(),
+             header: ChromaTokenHeader::XChromaToken,
+         },
+         database: connection.database.clone(),
+     })
+     .await
+     .unwrap()
}

////////////////////////////////////////////// DataSet /////////////////////////////////////////////
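
For orientation, a minimal usage sketch (not part of the change) of the reworked entry point: callers now build a Connection explicitly instead of relying on the CHROMA_HOST/CHROMA_TOKEN environment variables. The field values below are placeholders.

// Usage sketch with placeholder values; `client` panics on connection errors
// because it unwraps the ChromaClient::new result.
let connection = Connection {
    url: "http://localhost:8000".to_string(), // placeholder endpoint
    api_key: "my-token".to_string(),          // placeholder token, sent via the X-Chroma-Token header
    database: "my-database".to_string(),      // placeholder database name
};
let chroma = client(connection).await;
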
@@ -932,14 +922,15 @@ impl PartialEq for Throughput {

////////////////////////////////////////// RunningWorkload /////////////////////////////////////////

- /// A running workload is a workload that has been bound to a data set at a given throughput. It
- /// is assigned a name, uuid, and expiration time.
+ /// A running workload is a workload that has been bound to a data set and connection at a given
+ /// throughput. It is assigned a name, uuid, and expiration time.
#[derive(Clone, Debug)]
pub struct RunningWorkload {
    uuid: Uuid,
    name: String,
    workload: Workload,
    data_set: Arc<dyn DataSet>,
+     connection: Connection,
    expires: chrono::DateTime<chrono::FixedOffset>,
    throughput: Throughput,
}
@@ -959,6 +950,7 @@ impl From<WorkloadSummary> for Option<RunningWorkload> {
            name: s.name,
            workload: s.workload,
            data_set,
+             connection: s.connection,
            expires: s.expires,
            throughput: s.throughput,
        })
@@ -994,6 +986,8 @@ pub struct WorkloadSummary {
    pub workload: Workload,
    /// The data set the workload is bound to.
    pub data_set: serde_json::Value,
+     /// The connection to use.
+     pub connection: Connection,
    /// The expiration time of the workload.
    pub expires: chrono::DateTime<chrono::FixedOffset>,
    /// The throughput of the workload.
@@ -1007,6 +1001,7 @@ impl From<RunningWorkload> for WorkloadSummary {
            name: r.name,
            workload: r.workload,
            data_set: r.data_set.json(),
+             connection: r.connection,
            expires: r.expires,
            throughput: r.throughput,
        }
@@ -1041,6 +1036,7 @@ impl LoadHarness {
        name: String,
        workload: Workload,
        data_set: &Arc<dyn DataSet>,
+         connection: Connection,
        expires: chrono::DateTime<chrono::FixedOffset>,
        throughput: Throughput,
    ) -> Uuid {
@@ -1051,6 +1047,7 @@ impl LoadHarness {
            name,
            workload,
            data_set,
+             connection,
            expires,
            throughput,
        });
@@ -1180,8 +1177,9 @@ impl LoadService {
    pub fn start(
        &self,
        name: String,
-         data_set: String,
        mut workload: Workload,
+         data_set: String,
+         connection: Connection,
        expires: chrono::DateTime<chrono::FixedOffset>,
        throughput: Throughput,
    ) -> Result<Uuid, Error> {
@@ -1192,7 +1190,14 @@ impl LoadService {
        let res = {
            // SAFETY(rescrv): Mutex poisoning.
            let mut harness = self.harness.lock().unwrap();
-             Ok(harness.start(name, workload.clone(), data_set, expires, throughput))
+             Ok(harness.start(
+                 name,
+                 workload.clone(),
+                 data_set,
+                 connection,
+                 expires,
+                 throughput,
+             ))
        };
        self.save_persistent()?;
        res
@@ -1274,7 +1279,7 @@ impl LoadService {
        inhibit: Arc<AtomicBool>,
        spec: RunningWorkload,
    ) {
-         let client = Arc::new(client().await);
+         let client = Arc::new(client(spec.connection.clone()).await);
        let mut guac = Guacamole::new(spec.expires.timestamp_millis() as u64);
        let mut next_op = Instant::now();
        let (tx, mut rx) = tokio::sync::mpsc::channel(1000);
@@ -1627,8 +1632,9 @@ async fn start(
        .map_err(|err| Error::InvalidRequest(format!("could not parse rfc3339: {err:?}")))?;
    let uuid = state.load.start(
        req.name,
-         req.data_set,
        req.workload,
+         req.data_set,
+         req.connection,
        expires,
        req.throughput,
    )?;
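
Because Connection derives serde::Deserialize, the new `connection` field of the start request deserializes from a plain JSON object keyed by the struct's field names. A minimal sketch with placeholder values, assuming the serde_json crate is available:

// Sketch only: shows the JSON shape that maps onto Connection.
let connection: Connection = serde_json::from_value(serde_json::json!({
    "url": "http://localhost:8000", // placeholder endpoint
    "api_key": "",                  // placeholder credential
    "database": ""                  // placeholder database name
}))
.unwrap();
assert_eq!(connection.database, "");
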
@@ -1726,8 +1732,13 @@ mod tests {
        let load = LoadService::default();
        load.start(
            "foo".to_string(),
-             "nop".to_string(),
            Workload::ByName("get-no-filter".to_string()),
+             "nop".to_string(),
+             Connection {
+                 url: "http://localhost:8000".to_string(),
+                 api_key: "".to_string(),
+                 database: "".to_string(),
+             },
            (chrono::Utc::now() + chrono::Duration::seconds(10)).into(),
            Throughput::Constant(1.0),
        )
@@ -1833,8 +1844,13 @@ mod tests {
        .unwrap();
        load.start(
            "foo".to_string(),
-             "nop".to_string(),
            Workload::ByName("get-no-filter".to_string()),
+             "nop".to_string(),
+             Connection {
+                 url: "http://localhost:8000".to_string(),
+                 api_key: "".to_string(),
+                 database: "".to_string(),
+             },
            (chrono::Utc::now() + chrono::Duration::seconds(10)).into(),
            Throughput::Constant(1.0),
        )