Skip to content

Commit 2174e33

Browse files
committed
WIP Make Oauth2 Client & Config Pluggable
So test dependencies can be injected into the Google integration. Doing this statically with generics failed because of rust-lang/rust#100013. Refactor to used boxed vtable-dispatching types. It's much less elegant.
1 parent 053bcab commit 2174e33

File tree

15 files changed

+533
-303
lines changed

15 files changed

+533
-303
lines changed

cloud_scraper/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,5 @@ x509-parser = "0.16.0"
3939
[dev-dependencies]
4040
mockall = "0.13.0"
4141
once_cell = "1.19.0"
42+
tempfile = "3.15.0"
4243
tokio-test = "0.4.3"

cloud_scraper/src/core/engine.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@ use crate::core::node_handles::NodeHandles;
99
use crate::domain::config::Config;
1010
use crate::domain::mpsc_handle::{one_shot, OneshotMpscSenderHandle};
1111
use crate::domain::node::{LifecycleChannelHandle, Manager};
12-
use crate::domain::oauth2::BasicClientImpl;
12+
use crate::domain::oauth2::{ApplicationSecret, ExtraParameters};
13+
use crate::domain::oauth2::{BasicClientImpl, Client};
1314
use crate::integration::google::Source as GoogleSource;
1415
use crate::integration::log::Sink as LogSink;
1516
use crate::integration::stub::Source as StubSource;
1617
use crate::server::WebServer;
1718
use core::time::Duration;
1819
#[cfg(test)]
1920
use mockall::automock;
21+
use std::path::Path;
2022
use std::sync::atomic::AtomicBool;
2123
use std::sync::atomic::Ordering::SeqCst;
2224
use tokio::sync::Semaphore;
@@ -67,7 +69,7 @@ where
6769
let wait_duration = self.manager.core_config().exit_after();
6870

6971
let mut stub_source = StubSource::new(&self.manager);
70-
let google_source: GoogleSource<BasicClientImpl> =
72+
let google_source: GoogleSource =
7173
GoogleSource::new(&self.manager, self.server.get_web_channel_handle());
7274
let mut log_sink = LogSink::new(&self.manager, &stub_source.get_readonly_channel_handle());
7375

@@ -98,7 +100,15 @@ where
98100
.expect("Could not acquire semaphore");
99101
abort_handles.push(join_set.spawn(async move { log_sink.run(log_permit).await }));
100102
abort_handles.push(join_set.spawn(async move { stub_source.run(stub_permit).await }));
101-
abort_handles.push(join_set.spawn(async move { google_source.run(google_permit).await }));
103+
abort_handles.push(join_set.spawn(async move {
104+
google_source
105+
.run(
106+
google_permit,
107+
BasicClientImpl::get_auth_config,
108+
BasicClientImpl::new,
109+
)
110+
.await
111+
}));
102112

103113
let server = self.server.clone();
104114
abort_handles.push(join_set.spawn(async move {

cloud_scraper/src/domain/module_state.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
use async_trait::async_trait;
2+
use std::path::PathBuf;
23
use tokio::fs;
34

45
#[async_trait]
56
pub trait ModuleState {
67
fn path() -> &'static str;
7-
async fn path_for<Module>() -> Result<String, std::io::Error>
8+
async fn path_for<Module>() -> Result<PathBuf, std::io::Error>
89
where
910
Module: NamedModule,
1011
{
11-
let path = format!("{}/{}", Self::path(), Module::name());
12-
fs::create_dir_all(path.clone()).await?;
12+
Self::path_for_name(Module::name()).await
13+
}
14+
15+
async fn path_for_name(name: &str) -> Result<PathBuf, std::io::Error> {
16+
let path = PathBuf::from(Self::path()).join(name);
17+
fs::create_dir_all(&path).await?;
1318
Ok(path)
1419
}
1520
}

cloud_scraper/src/domain/oauth2/application_secret.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,13 @@ impl ApplicationSecret {
3939
.set_redirect_uri(redirect_uri)
4040
}
4141
}
42+
43+
#[cfg(test)]
44+
mod tests {
45+
use super::*;
46+
use crate::assert_is_send_and_sync;
47+
48+
fn send_and_sync() {
49+
assert_is_send_and_sync!(ApplicationSecret);
50+
}
51+
}

cloud_scraper/src/domain/oauth2/client.rs

Lines changed: 129 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
use crate::core::module::State;
2+
use crate::domain::module_state::ModuleState;
13
use crate::domain::mpsc_handle::one_shot;
24
use crate::domain::node::Manager;
35
use crate::domain::oauth2::extra_parameters::{ExtraParameters, WithExtraParametersExt};
46
use crate::domain::oauth2::token::{BasicTokenResponseExt, Token, TokenExt, TokenStatus};
5-
use crate::domain::oauth2::ApplicationSecret;
7+
use crate::domain::oauth2::{ApplicationSecret, Config, PersistableConfig};
8+
use crate::integration::google::auth::ConfigQuery;
69
use crate::server::Event::Redirect;
710
use crate::server::{Code, Event, WebEventChannelHandle};
811
use crate::static_init::error::Error::FailedAfterRetries;
@@ -18,7 +21,9 @@ use oauth2::{
1821
RefreshToken, Scope,
1922
};
2023
use std::future::Future;
24+
use std::io;
2125
use std::path::{Path, PathBuf};
26+
use std::pin::Pin;
2227
use std::sync::Arc;
2328
use tokio::fs;
2429
use tokio::sync::broadcast::error::RecvError;
@@ -28,18 +33,28 @@ use tokio::time::sleep;
2833
use Error::Oauth2CsrfMismatch;
2934
use Event::Oauth2Code;
3035

31-
pub trait Client: Clone + Send + Sized + Sync + 'static {
36+
pub trait Client: Send + Sync + 'static {
3237
fn new(
3338
application_secret: ApplicationSecret,
3439
extra_parameters: &ExtraParameters,
3540
manager: &Manager,
3641
token_path: &Path,
3742
web_channel_handle: &WebEventChannelHandle,
38-
) -> Self;
39-
fn get_token(
40-
&self,
41-
scopes: &[&str],
42-
) -> impl Future<Output = Result<AccessToken, Error>> + Send + Sync;
43+
) -> Pin<Box<Self>>
44+
where
45+
Self: Sized + 'static;
46+
fn get_auth_config<'async_trait>(
47+
name: &'async_trait str,
48+
) -> Pin<Box<dyn Future<Output = Result<impl Config, io::Error>> + Send + 'async_trait>>
49+
where
50+
Self: Sized + Sync + 'async_trait;
51+
fn duplicate(&self) -> Pin<Box<dyn Client>>;
52+
fn get_token<'async_trait>(
53+
&'async_trait self,
54+
scopes: &'async_trait [&'async_trait str],
55+
) -> Pin<Box<dyn Future<Output = Result<AccessToken, Error>> + Send + 'async_trait>>
56+
where
57+
Self: Sync + 'async_trait;
4358
}
4459

4560
#[derive(Clone)]
@@ -60,31 +75,62 @@ impl Client for BasicClientImpl {
6075
manager: &Manager,
6176
token_path: &Path,
6277
web_channel_handle: &WebEventChannelHandle,
63-
) -> Self {
78+
) -> Pin<Box<Self>> {
6479
let basic_client = application_secret.to_client();
65-
Self {
80+
Box::pin(Self {
6681
basic_client,
6782
extra_parameters: extra_parameters.clone(),
6883
manager: manager.clone(),
6984
retry_max: 9,
7085
retry_period: std::time::Duration::from_secs(2),
7186
token_path: token_path.to_owned(),
7287
web_channel_handle: web_channel_handle.clone(),
73-
}
88+
})
7489
}
7590

76-
async fn get_token(&self, scopes: &[&str]) -> Result<AccessToken, Error> {
77-
match self.get_token_status_from_file().await {
78-
TokenStatus::Ok(token) => Ok(token.access_token().clone()),
79-
TokenStatus::Expired(refresh_token) => self
80-
.refresh_token(&refresh_token)
81-
.await
82-
.map(|token| token.access_token().clone()),
83-
TokenStatus::Absent => self.retrieve_token(scopes).await.map(|token| {
84-
debug!("Token retrieved: {:?}", token);
85-
token.access_token().clone()
86-
}),
87-
}
91+
fn get_auth_config<'async_trait>(
92+
name: &'async_trait str,
93+
) -> Pin<Box<dyn Future<Output = Result<impl Config, io::Error>> + Send + 'async_trait>>
94+
where
95+
Self: Sized + Sync + 'async_trait,
96+
{
97+
Box::pin(
98+
async move { Ok(ConfigQuery::read_config(&State::path_for_name(name).await?).await?) },
99+
)
100+
}
101+
102+
fn duplicate(&self) -> Pin<Box<dyn Client>> {
103+
Box::pin(Self {
104+
basic_client: self.basic_client.clone(),
105+
extra_parameters: self.extra_parameters.clone(),
106+
manager: self.manager.clone(),
107+
retry_max: self.retry_max,
108+
retry_period: self.retry_period,
109+
token_path: self.token_path.to_owned(),
110+
web_channel_handle: self.web_channel_handle.clone(),
111+
})
112+
}
113+
114+
fn get_token<'async_trait>(
115+
&'async_trait self,
116+
scopes: &'async_trait [&'async_trait str],
117+
) -> Pin<Box<dyn Future<Output = Result<AccessToken, Error>> + Send + 'async_trait>>
118+
where
119+
Self: Sync + 'async_trait,
120+
{
121+
Box::pin(async move {
122+
match self.get_token_status_from_file().await {
123+
TokenStatus::Ok(token) => Ok(token.access_token().clone()),
124+
TokenStatus::Expired(refresh_token) => self
125+
.refresh_token(&refresh_token)
126+
.await
127+
.map(|token| token.access_token().clone()),
128+
TokenStatus::Absent => self.retrieve_token(scopes).await.map(|token| {
129+
debug!("Token retrieved: {:?}", token);
130+
token.access_token().clone()
131+
}),
132+
}
133+
})
88134
}
89135
}
90136

@@ -300,6 +346,16 @@ impl BasicClientImpl {
300346
pub mod tests {
301347
use super::*;
302348

349+
mod send_and_sync {
350+
use super::*;
351+
use crate::assert_is_send_and_sync;
352+
353+
#[test]
354+
fn basic_client_impl_is_send_and_sync() {
355+
assert_is_send_and_sync!(BasicClientImpl);
356+
}
357+
}
358+
303359
mod make_redirect_url {
304360
use super::*;
305361
use crate::domain::config::tests::test_config;
@@ -381,12 +437,61 @@ pub mod tests {
381437
}
382438

383439
mod access_token {
440+
use crate::assert_is_send_and_sync;
384441
use oauth2::AccessToken;
385442

386443
#[test]
387444
fn test_is_send_and_sync() {
388-
fn is_send_and_sync<T: Send + Sync>() {}
389-
is_send_and_sync::<AccessToken>();
445+
assert_is_send_and_sync!(AccessToken);
446+
}
447+
}
448+
449+
mod get_auth_config {
450+
use super::*;
451+
use crate::domain::config::Config as CoreConfig;
452+
use crate::domain::module_state::NamedModule;
453+
use crate::domain::oauth2::config::ConfigProperties;
454+
455+
pub struct NamedType;
456+
457+
impl NamedModule for NamedType {
458+
fn name() -> &'static str {
459+
"test"
460+
}
461+
}
462+
463+
impl NamedType {
464+
async fn typed_test<T: Client>(&self) -> ApplicationSecret {
465+
task::spawn(async move {
466+
let config = T::get_auth_config("name").await.unwrap();
467+
assert_eq!(config.auth_uri(), "auth_uri");
468+
assert_eq!(
469+
config.auth_provider_x509_cert_url(),
470+
"auth_provider_x509_cert_url"
471+
);
472+
assert_eq!(config.client_email(), Some("client_email"));
473+
assert_eq!(config.client_id(), "client_id");
474+
assert_eq!(config.client_secret(), "client_secret");
475+
assert_eq!(config.client_x509_cert_url(), Some("client_x509_cert_url"));
476+
assert_eq!(config.project_id(), "project_id");
477+
assert_eq!(config.redirect_uris(), &vec!["redirect_uris".to_string()]);
478+
assert_eq!(config.token_uri(), "token_uri");
479+
let core_config = CoreConfig::with_all_properties(None, None, None, None);
480+
let app_secret = config.to_application_secret(&core_config);
481+
app_secret
482+
})
483+
.await
484+
.unwrap()
485+
}
486+
}
487+
488+
// Trying to reproduce
489+
// lifetime bound not satisfied
490+
// Note: this is a known limitation that will be removed in the future (see issue #100013 <https:// github. com/ rust-lang/ rust/ issues/ 100013> for more information)
491+
#[tokio::test]
492+
async fn test_get_auth_config() {
493+
let named_type = NamedType {};
494+
let _app_secret = named_type.typed_test::<BasicClientImpl>().await;
390495
}
391496
}
392497
}

0 commit comments

Comments
 (0)