Commit 321fd09

[ENH] Make connection configurable per running load-service workload (chroma-core#4374)
1 parent b39d2e4 commit 321fd09

File tree

1 file changed: +54 -38 lines changed

rust/load/src/lib.rs

Lines changed: 54 additions & 38 deletions
@@ -175,39 +175,29 @@ impl ZipfCache {
     }
 }

-////////////////////////////////////////////// client //////////////////////////////////////////////
+////////////////////////////////////////////// Connection /////////////////////////////////////////////

-/// Instantiate a new Chroma client. This will use the CHROMA_HOST environment variable (or
-/// http://localhost:8000 when unset) as the argument to [client_for_url].
-pub async fn client() -> ChromaClient {
-    let url = std::env::var("CHROMA_HOST").unwrap_or_else(|_| "http://localhost:8000".into());
-    client_for_url(url).await
+#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
+pub struct Connection {
+    pub url: String,
+    pub api_key: String,
+    pub database: String,
 }

-/// Create a new Chroma client for the given URL. This will use the CHROMA_TOKEN environment
-/// variable if set, or no authentication if unset.
-pub async fn client_for_url(url: String) -> ChromaClient {
-    if let Ok(auth) = std::env::var("CHROMA_TOKEN") {
-        let db = std::env::var("CHROMA_DATABASE").unwrap_or_else(|_| "hf-tiny-stories".into());
-        ChromaClient::new(ChromaClientOptions {
-            url: Some(url),
-            auth: ChromaAuthMethod::TokenAuth {
-                token: auth,
-                header: ChromaTokenHeader::XChromaToken,
-            },
-            database: db,
-        })
-        .await
-        .unwrap()
-    } else {
-        ChromaClient::new(ChromaClientOptions {
-            url: Some(url),
-            auth: ChromaAuthMethod::None,
-            database: "default_database".to_string(),
-        })
-        .await
-        .unwrap()
-    }
+////////////////////////////////////////////// client //////////////////////////////////////////////
+
+/// Instantiate a new Chroma client.
+pub async fn client(connection: Connection) -> ChromaClient {
+    ChromaClient::new(ChromaClientOptions {
+        url: Some(connection.url.clone()),
+        auth: ChromaAuthMethod::TokenAuth {
+            token: connection.api_key.clone(),
+            header: ChromaTokenHeader::XChromaToken,
+        },
+        database: connection.database.clone(),
+    })
+    .await
+    .unwrap()
 }

 ////////////////////////////////////////////// DataSet /////////////////////////////////////////////
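With this hunk, connection details come from the caller rather than from the CHROMA_HOST / CHROMA_TOKEN / CHROMA_DATABASE environment variables that the old client() and client_for_url() read. A minimal sketch of calling the new API from code in this crate (assumes an async context; the url, api_key, and database values below are placeholders, not taken from the commit):

    // Sketch only: `Connection`, `client`, and `ChromaClient` come from rust/load/src/lib.rs.
    let connection = Connection {
        url: "http://localhost:8000".to_string(), // placeholder endpoint
        api_key: "".to_string(),                  // empty token, e.g. for a local unauthenticated server
        database: "default_database".to_string(), // placeholder database name
    };
    // `client` now consumes the Connection instead of consulting environment variables.
    let chroma: ChromaClient = client(connection).await;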
@@ -932,14 +922,15 @@ impl PartialEq for Throughput {

 ////////////////////////////////////////// RunningWorkload /////////////////////////////////////////

-/// A running workload is a workload that has been bound to a data set at a given throughput. It
-/// is assigned a name, uuid, and expiration time.
+/// A running workload is a workload that has been bound to a data set and connection at a given
+/// throughput. It is assigned a name, uuid, and expiration time.
 #[derive(Clone, Debug)]
 pub struct RunningWorkload {
     uuid: Uuid,
     name: String,
     workload: Workload,
     data_set: Arc<dyn DataSet>,
+    connection: Connection,
     expires: chrono::DateTime<chrono::FixedOffset>,
     throughput: Throughput,
 }
@@ -959,6 +950,7 @@ impl From<WorkloadSummary> for Option<RunningWorkload> {
             name: s.name,
             workload: s.workload,
             data_set,
+            connection: s.connection,
             expires: s.expires,
             throughput: s.throughput,
         })
@@ -994,6 +986,8 @@ pub struct WorkloadSummary {
     pub workload: Workload,
     /// The data set the workload is bound to.
     pub data_set: serde_json::Value,
+    /// The connection to use.
+    pub connection: Connection,
     /// The expiration time of the workload.
     pub expires: chrono::DateTime<chrono::FixedOffset>,
     /// The throughput of the workload.
@@ -1007,6 +1001,7 @@ impl From<RunningWorkload> for WorkloadSummary {
             name: r.name,
             workload: r.workload,
             data_set: r.data_set.json(),
+            connection: r.connection,
             expires: r.expires,
             throughput: r.throughput,
         }
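Because Connection derives serde::Serialize and serde::Deserialize, it serializes as part of a WorkloadSummary. A small sketch of the wire shape of the connection field alone, assuming serde's default field naming (which the derive in the first hunk uses) and that serde_json is available as a dependency:

    // Sketch: round-trip a Connection through JSON to show its serialized shape.
    let connection = Connection {
        url: "http://localhost:8000".to_string(),
        api_key: "".to_string(),
        database: "default_database".to_string(),
    };
    let json = serde_json::to_string(&connection).unwrap();
    assert_eq!(
        json,
        r#"{"url":"http://localhost:8000","api_key":"","database":"default_database"}"#
    );
    let back: Connection = serde_json::from_str(&json).unwrap();
    assert_eq!(back.database, "default_database");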
@@ -1041,6 +1036,7 @@ impl LoadHarness {
         name: String,
         workload: Workload,
         data_set: &Arc<dyn DataSet>,
+        connection: Connection,
         expires: chrono::DateTime<chrono::FixedOffset>,
         throughput: Throughput,
     ) -> Uuid {
@@ -1051,6 +1047,7 @@ impl LoadHarness {
             name,
             workload,
             data_set,
+            connection,
             expires,
             throughput,
         });
@@ -1180,8 +1177,9 @@ impl LoadService {
     pub fn start(
         &self,
         name: String,
-        data_set: String,
         mut workload: Workload,
+        data_set: String,
+        connection: Connection,
         expires: chrono::DateTime<chrono::FixedOffset>,
         throughput: Throughput,
     ) -> Result<Uuid, Error> {
@@ -1192,7 +1190,14 @@ impl LoadService {
         let res = {
             // SAFETY(rescrv): Mutex poisoning.
             let mut harness = self.harness.lock().unwrap();
-            Ok(harness.start(name, workload.clone(), data_set, expires, throughput))
+            Ok(harness.start(
+                name,
+                workload.clone(),
+                data_set,
+                connection,
+                expires,
+                throughput,
+            ))
         };
         self.save_persistent()?;
         res
@@ -1274,7 +1279,7 @@ impl LoadService {
         inhibit: Arc<AtomicBool>,
         spec: RunningWorkload,
     ) {
-        let client = Arc::new(client().await);
+        let client = Arc::new(client(spec.connection.clone()).await);
         let mut guac = Guacamole::new(spec.expires.timestamp_millis() as u64);
         let mut next_op = Instant::now();
         let (tx, mut rx) = tokio::sync::mpsc::channel(1000);
@@ -1627,8 +1632,9 @@ async fn start(
         .map_err(|err| Error::InvalidRequest(format!("could not parse rfc3339: {err:?}")))?;
     let uuid = state.load.start(
         req.name,
-        req.data_set,
         req.workload,
+        req.data_set,
+        req.connection,
         expires,
         req.throughput,
     )?;
@@ -1726,8 +1732,13 @@ mod tests {
         let load = LoadService::default();
         load.start(
             "foo".to_string(),
-            "nop".to_string(),
             Workload::ByName("get-no-filter".to_string()),
+            "nop".to_string(),
+            Connection {
+                url: "http://localhost:8000".to_string(),
+                api_key: "".to_string(),
+                database: "".to_string(),
+            },
             (chrono::Utc::now() + chrono::Duration::seconds(10)).into(),
             Throughput::Constant(1.0),
         )
@@ -1833,8 +1844,13 @@ mod tests {
             .unwrap();
         load.start(
             "foo".to_string(),
-            "nop".to_string(),
             Workload::ByName("get-no-filter".to_string()),
+            "nop".to_string(),
+            Connection {
+                url: "http://localhost:8000".to_string(),
+                api_key: "".to_string(),
+                database: "".to_string(),
+            },
             (chrono::Utc::now() + chrono::Duration::seconds(10)).into(),
             Throughput::Constant(1.0),
         )

0 commit comments
