Skip to content

Commit 65eb9b9

Browse files
prose test, incorporate more API changes
1 parent 209f9f7 commit 65eb9b9

File tree

11 files changed

+728
-570
lines changed

11 files changed

+728
-570
lines changed

src/action/csfle/create_data_key.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ impl ClientEncryption {
88
/// `await` will return d[`Result<Binary>`] (subtype 0x04) with the _id of the created
99
/// document as a UUID.
1010
#[deeplink]
11-
pub fn create_data_key(&self, master_key: MasterKey) -> CreateDataKey {
11+
pub fn create_data_key(&self, master_key: impl Into<MasterKey>) -> CreateDataKey {
1212
CreateDataKey {
1313
client_enc: self,
14-
master_key,
14+
master_key: master_key.into(),
1515
options: None,
1616
#[cfg(test)]
1717
test_kms_provider: None,

src/client/csfle/client_encryption.rs

Lines changed: 159 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ mod encrypt;
55

66
use mongocrypt::{ctx::KmsProvider, Crypt};
77
use serde::{Deserialize, Serialize};
8+
use typed_builder::TypedBuilder;
89

910
use crate::{
1011
bson::{doc, spec::BinarySubtype, Binary, RawBinaryRef, RawDocumentBuf},
@@ -193,57 +194,169 @@ impl ClientEncryption {
193194
#[non_exhaustive]
194195
#[allow(missing_docs)]
195196
pub enum MasterKey {
196-
#[serde(rename_all = "camelCase")]
197-
Aws {
198-
region: String,
199-
/// The Amazon Resource Name (ARN) to the AWS customer master key (CMK).
200-
key: String,
201-
/// An alternate host identifier to send KMS requests to. May include port number. Defaults
202-
/// to "kms.REGION.amazonaws.com"
203-
endpoint: Option<String>,
204-
},
205-
#[serde(rename_all = "camelCase")]
206-
Azure {
207-
/// Host with optional port. Example: "example.vault.azure.net".
208-
key_vault_endpoint: String,
209-
key_name: String,
210-
/// A specific version of the named key, defaults to using the key's primary version.
211-
key_version: Option<String>,
212-
},
213-
#[serde(rename_all = "camelCase")]
214-
Gcp {
215-
project_id: String,
216-
location: String,
217-
key_ring: String,
218-
key_name: String,
219-
/// A specific version of the named key, defaults to using the key's primary version.
220-
key_version: Option<String>,
221-
/// Host with optional port. Defaults to "cloudkms.googleapis.com".
222-
endpoint: Option<String>,
223-
},
224-
/// Master keys are not applicable to `KmsProvider::Local`.
225-
Local,
226-
#[serde(rename_all = "camelCase")]
227-
Kmip {
228-
/// keyId is the KMIP Unique Identifier to a 96 byte KMIP Secret Data managed object. If
229-
/// keyId is omitted, the driver creates a random 96 byte KMIP Secret Data managed object.
230-
key_id: Option<String>,
231-
/// If true (recommended), the KMIP server must decrypt this key. Defaults to false.
232-
delegated: Option<bool>,
233-
/// Host with optional port.
234-
endpoint: Option<String>,
235-
},
197+
Aws(AwsMasterKey),
198+
Azure(AzureMasterKey),
199+
Gcp(GcpMasterKey),
200+
Kmip(KmipMasterKey),
201+
Local(LocalMasterKey),
202+
}
203+
204+
/// An AWS master key.
205+
#[serde_with::skip_serializing_none]
206+
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
207+
#[builder(field_defaults(default, setter(into)))]
208+
#[serde(rename_all = "camelCase")]
209+
#[non_exhaustive]
210+
pub struct AwsMasterKey {
211+
/// The name for the key. The value for this field must be the same as the corresponding
212+
/// [`KmsProvider`](mongocrypt::ctx::KmsProvider)'s name.
213+
#[serde(skip)]
214+
pub name: Option<String>,
215+
216+
/// The region.
217+
pub region: String,
218+
219+
/// The Amazon Resource Name (ARN) to the AWS customer master key (CMK).
220+
pub key: String,
221+
222+
/// An alternate host identifier to send KMS requests to. May include port number. Defaults to
223+
/// "kms.<region>.amazonaws.com".
224+
pub endpoint: Option<String>,
225+
}
226+
227+
impl From<AwsMasterKey> for MasterKey {
228+
fn from(aws_master_key: AwsMasterKey) -> Self {
229+
Self::Aws(aws_master_key)
230+
}
231+
}
232+
233+
/// An Azure master key.
234+
#[serde_with::skip_serializing_none]
235+
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
236+
#[builder(field_defaults(default, setter(into)))]
237+
#[serde(rename_all = "camelCase")]
238+
#[non_exhaustive]
239+
pub struct AzureMasterKey {
240+
/// The name for the key. The value for this field must be the same as the corresponding
241+
/// [`KmsProvider`](mongocrypt::ctx::KmsProvider)'s name.
242+
#[serde(skip)]
243+
pub name: Option<String>,
244+
245+
/// Host with optional port. Example: "example.vault.azure.net".
246+
pub key_vault_endpoint: String,
247+
248+
/// The key name.
249+
pub key_name: String,
250+
251+
/// A specific version of the named key, defaults to using the key's primary version.
252+
pub key_version: Option<String>,
253+
}
254+
255+
impl From<AzureMasterKey> for MasterKey {
256+
fn from(azure_master_key: AzureMasterKey) -> Self {
257+
Self::Azure(azure_master_key)
258+
}
259+
}
260+
261+
/// A GCP master key.
262+
#[serde_with::skip_serializing_none]
263+
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
264+
#[builder(field_defaults(default, setter(into)))]
265+
#[serde(rename_all = "camelCase")]
266+
#[non_exhaustive]
267+
pub struct GcpMasterKey {
268+
/// The name for the key. The value for this field must be the same as the corresponding
269+
/// [`KmsProvider`](mongocrypt::ctx::KmsProvider)'s name.
270+
#[serde(skip)]
271+
pub name: Option<String>,
272+
273+
/// The project ID.
274+
pub project_id: String,
275+
276+
/// The location.
277+
pub location: String,
278+
279+
/// The key ring.
280+
pub key_ring: String,
281+
282+
/// The key name.
283+
pub key_name: String,
284+
285+
/// A specific version of the named key. Defaults to using the key's primary version.
286+
pub key_version: Option<String>,
287+
288+
/// Host with optional port. Defaults to "cloudkms.googleapis.com".
289+
pub endpoint: Option<String>,
290+
}
291+
292+
impl From<GcpMasterKey> for MasterKey {
293+
fn from(gcp_master_key: GcpMasterKey) -> Self {
294+
Self::Gcp(gcp_master_key)
295+
}
296+
}
297+
298+
/// A local master key.
299+
#[serde_with::skip_serializing_none]
300+
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
301+
#[builder(field_defaults(default, setter(into)))]
302+
#[serde(rename_all = "camelCase")]
303+
#[non_exhaustive]
304+
pub struct LocalMasterKey {
305+
/// The name for the key. The value for this field must be the same as the corresponding
306+
/// [`KmsProvider`](mongocrypt::ctx::KmsProvider)'s name.
307+
#[serde(skip)]
308+
pub name: Option<String>,
309+
}
310+
311+
impl From<LocalMasterKey> for MasterKey {
312+
fn from(local_master_key: LocalMasterKey) -> Self {
313+
Self::Local(local_master_key)
314+
}
315+
}
316+
317+
/// A KMIP master key.
318+
#[serde_with::skip_serializing_none]
319+
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
320+
#[builder(field_defaults(default, setter(into)))]
321+
#[serde(rename_all = "camelCase")]
322+
#[non_exhaustive]
323+
pub struct KmipMasterKey {
324+
/// The name for the key. The value for this field must be the same as the corresponding
325+
/// [`KmsProvider`](mongocrypt::ctx::KmsProvider)'s name.
326+
#[serde(skip)]
327+
pub name: Option<String>,
328+
329+
/// The KMIP Unique Identifier to a 96 byte KMIP Secret Data managed object. If this field is
330+
/// not specified, the driver creates a random 96 byte KMIP Secret Data managed object.
331+
pub key_id: Option<String>,
332+
333+
/// If true (recommended), the KMIP server must decrypt this key. Defaults to false.
334+
pub delegated: Option<bool>,
335+
336+
/// Host with optional port.
337+
pub endpoint: Option<String>,
338+
}
339+
340+
impl From<KmipMasterKey> for MasterKey {
341+
fn from(kmip_master_key: KmipMasterKey) -> Self {
342+
Self::Kmip(kmip_master_key)
343+
}
236344
}
237345

238346
impl MasterKey {
239347
/// Returns the `KmsProvider` associated with this key.
240348
pub fn provider(&self) -> KmsProvider {
241-
match self {
242-
MasterKey::Aws { .. } => KmsProvider::Aws { name: None },
243-
MasterKey::Azure { .. } => KmsProvider::Azure { name: None },
244-
MasterKey::Gcp { .. } => KmsProvider::Gcp { name: None },
245-
MasterKey::Kmip { .. } => KmsProvider::Kmip { name: None },
246-
MasterKey::Local => KmsProvider::Local { name: None },
349+
let (provider, name) = match self {
350+
MasterKey::Aws(AwsMasterKey { name, .. }) => (KmsProvider::aws(), name.clone()),
351+
MasterKey::Azure(AzureMasterKey { name, .. }) => (KmsProvider::azure(), name.clone()),
352+
MasterKey::Gcp(GcpMasterKey { name, .. }) => (KmsProvider::gcp(), name.clone()),
353+
MasterKey::Kmip(KmipMasterKey { name, .. }) => (KmsProvider::kmip(), name.clone()),
354+
MasterKey::Local(LocalMasterKey { name, .. }) => (KmsProvider::local(), name.clone()),
355+
};
356+
if let Some(name) = name {
357+
provider.with_name(name)
358+
} else {
359+
provider
247360
}
248361
}
249362
}

src/client/csfle/client_encryption/create_data_key.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ impl ClientEncryption {
4242
opts: Option<DataKeyOptions>,
4343
) -> Result<Ctx> {
4444
let mut builder = self.crypt.ctx_builder();
45-
let mut key_doc = doc! { "provider": kms_provider.name() };
46-
if !matches!(master_key, MasterKey::Local) {
45+
let mut key_doc = doc! { "provider": kms_provider.as_string() };
46+
if !matches!(master_key, MasterKey::Local { .. }) {
4747
let master_doc = bson::to_document(&master_key)?;
4848
key_doc.extend(master_doc);
4949
}

src/client/csfle/options.rs

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -132,28 +132,32 @@ impl KmsProviders {
132132

133133
#[cfg(test)]
134134
pub(crate) fn set_test_options(&mut self) {
135-
use crate::{bson::doc, test::KMS_PROVIDERS_MAP};
136-
137-
for (provider, credentials) in self.credentials.iter_mut().filter(|(p, _)| {
138-
matches!(
139-
p,
140-
KmsProvider::Aws { .. }
141-
| KmsProvider::Azure { .. }
142-
| KmsProvider::Gcp { .. }
143-
| KmsProvider::Kmip { .. }
144-
)
145-
}) {
146-
let (test_credentials, tls) = KMS_PROVIDERS_MAP.get(provider).unwrap().clone();
147-
*credentials = test_credentials;
135+
use mongocrypt::ctx::KmsProviderType;
136+
137+
use crate::{bson::doc, test::csfle::ALL_KMS_PROVIDERS};
138+
139+
let all_kms_providers = ALL_KMS_PROVIDERS.clone();
140+
let mut aws_tls_options = None;
141+
for (provider, test_credentials, tls_options) in all_kms_providers {
142+
let Some(credentials) = self.credentials.get_mut(&provider) else {
143+
continue;
144+
};
145+
if !matches!(provider.provider_type(), KmsProviderType::Local) {
146+
*credentials = test_credentials;
147+
}
148+
149+
if let Some(tls_options) = tls_options {
150+
if matches!(provider.provider_type(), KmsProviderType::Aws) {
151+
aws_tls_options = Some(tls_options.clone());
152+
}
148153

149-
if let Some(tls) = tls {
150154
self.tls_options
151155
.get_or_insert_with(KmsProvidersTlsOptions::new)
152-
.insert(provider.clone(), tls);
156+
.insert(provider.clone(), tls_options);
153157
}
154158
}
155159

156-
let aws_temp_provider = KmsProvider::Other("awsTemporary".to_string());
160+
let aws_temp_provider = KmsProvider::other("awsTemporary".to_string());
157161
if self.credentials.contains_key(&aws_temp_provider) {
158162
let aws_credentials = doc! {
159163
"accessKeyId": std::env::var("CSFLE_AWS_TEMP_ACCESS_KEY_ID").unwrap(),
@@ -162,20 +166,16 @@ impl KmsProviders {
162166
};
163167
self.credentials.insert(KmsProvider::aws(), aws_credentials);
164168

165-
let aws_tls = KMS_PROVIDERS_MAP
166-
.get(&KmsProvider::aws())
167-
.and_then(|(_, t)| t.as_ref());
168-
if let Some(tls) = aws_tls {
169+
if let Some(ref aws_tls_options) = aws_tls_options {
169170
self.tls_options
170171
.get_or_insert_with(KmsProvidersTlsOptions::new)
171-
.insert(KmsProvider::aws(), tls.clone());
172+
.insert(KmsProvider::aws(), aws_tls_options.clone());
172173
}
173174

174175
self.clear(&aws_temp_provider);
175176
}
176177

177-
let aws_temp_no_session_token_provider =
178-
KmsProvider::Other("awsTemporaryNoSessionToken".to_string());
178+
let aws_temp_no_session_token_provider = KmsProvider::other("awsTemporaryNoSessionToken");
179179
if self
180180
.credentials
181181
.contains_key(&aws_temp_no_session_token_provider)
@@ -186,14 +186,12 @@ impl KmsProviders {
186186
};
187187
self.credentials.insert(KmsProvider::aws(), aws_credentials);
188188

189-
let aws_tls = KMS_PROVIDERS_MAP
190-
.get(&KmsProvider::aws())
191-
.and_then(|(_, t)| t.as_ref());
192-
if let Some(tls) = aws_tls {
189+
if let Some(aws_tls_options) = aws_tls_options {
193190
self.tls_options
194191
.get_or_insert_with(KmsProvidersTlsOptions::new)
195-
.insert(KmsProvider::aws(), tls.clone());
192+
.insert(KmsProvider::aws(), aws_tls_options);
196193
}
194+
197195
self.clear(&aws_temp_no_session_token_provider);
198196
}
199197
}

src/client/csfle/state_machine.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::{
66

77
use bson::{rawdoc, Document, RawDocument, RawDocumentBuf};
88
use futures_util::{stream, TryStreamExt};
9-
use mongocrypt::ctx::{Ctx, KmsProvider, State};
9+
use mongocrypt::ctx::{Ctx, KmsProviderType, State};
1010
use rayon::ThreadPool;
1111
use tokio::{
1212
io::{AsyncReadExt, AsyncWriteExt},
@@ -216,8 +216,8 @@ impl CryptExecutor {
216216
continue;
217217
}
218218

219-
match provider {
220-
KmsProvider::Aws { .. } => {
219+
match provider.provider_type() {
220+
KmsProviderType::Aws => {
221221
#[cfg(feature = "aws-auth")]
222222
{
223223
use crate::{
@@ -237,7 +237,7 @@ impl CryptExecutor {
237237
if let Some(token) = aws_creds.session_token() {
238238
creds.append("sessionToken", token);
239239
}
240-
kms_providers.append(provider.name(), creds);
240+
kms_providers.append(provider.as_string(), creds);
241241
}
242242
#[cfg(not(feature = "aws-auth"))]
243243
{
@@ -247,11 +247,13 @@ impl CryptExecutor {
247247
));
248248
}
249249
}
250-
KmsProvider::Azure { .. } => {
250+
KmsProviderType::Azure => {
251251
#[cfg(feature = "azure-kms")]
252252
{
253-
kms_providers
254-
.append(provider.name(), self.azure.get_token().await?);
253+
kms_providers.append(
254+
provider.as_string(),
255+
self.azure.get_token().await?,
256+
);
255257
}
256258
#[cfg(not(feature = "azure-kms"))]
257259
{
@@ -261,7 +263,7 @@ impl CryptExecutor {
261263
));
262264
}
263265
}
264-
KmsProvider::Gcp { .. } => {
266+
KmsProviderType::Gcp => {
265267
#[cfg(feature = "gcp-kms")]
266268
{
267269
use crate::runtime::HttpClient;

0 commit comments

Comments
 (0)