Skip to content

client: support custom pool keys #204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 74 additions & 54 deletions src/client/legacy/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::error::Error as StdError;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{self, Poll};
use std::time::Duration;

Expand Down Expand Up @@ -35,15 +36,16 @@ type BoxSendFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
/// `Client` is cheap to clone and cloning is the recommended way to share a `Client`. The
/// underlying connection pool will be reused.
#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
pub struct Client<C, B> {
pub struct Client<C, B, PK: pool::Key = DefaultPoolKey> {
config: Config,
connector: C,
exec: Exec,
#[cfg(feature = "http1")]
h1_builder: hyper::client::conn::http1::Builder,
#[cfg(feature = "http2")]
h2_builder: hyper::client::conn::http2::Builder<Exec>,
pool: pool::Pool<PoolClient<B>, PoolKey>,
pool_key: Arc<dyn Fn(&mut http::request::Parts) -> Result<PK, Error> + Send + Sync + 'static>,
pool: pool::Pool<PoolClient<B>, PK>,
}

#[derive(Clone, Copy, Debug)]
Expand Down Expand Up @@ -90,7 +92,7 @@ macro_rules! e {
}

// We might change this... :shrug:
type PoolKey = (http::uri::Scheme, http::uri::Authority);
type DefaultPoolKey = (http::uri::Scheme, http::uri::Authority);

enum TrySendError<B> {
Retryable {
Expand Down Expand Up @@ -143,12 +145,13 @@ impl Client<(), ()> {
}
}

impl<C, B> Client<C, B>
impl<C, B, PK> Client<C, B, PK>
where
C: Connect + Clone + Send + Sync + 'static,
B: Body + Send + 'static + Unpin,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
PK: pool::Key,
{
/// Send a `GET` request to the supplied `Uri`.
///
Expand Down Expand Up @@ -214,40 +217,29 @@ where
/// # }
/// # fn main() {}
/// ```
pub fn request(&self, mut req: Request<B>) -> ResponseFuture {
let is_http_connect = req.method() == Method::CONNECT;
match req.version() {
Version::HTTP_11 => (),
Version::HTTP_10 => {
if is_http_connect {
warn!("CONNECT is not allowed for HTTP/1.0");
return ResponseFuture::new(future::err(e!(UserUnsupportedRequestMethod)));
}
}
Version::HTTP_2 => (),
// completely unsupported HTTP version (like HTTP/0.9)!
other => return ResponseFuture::error_version(other),
};

let pool_key = match extract_domain(req.uri_mut(), is_http_connect) {
pub fn request(&self, req: Request<B>) -> ResponseFuture {
let (mut parts, body) = req.into_parts();
let pool_key = match (self.pool_key)(&mut parts) {
Ok(s) => s,
Err(err) => {
return ResponseFuture::new(future::err(err));
}
};
let req = Request::from_parts(parts, body);

ResponseFuture::new(self.clone().send_request(req, pool_key))
}

async fn send_request(
self,
mut req: Request<B>,
pool_key: PoolKey,
pool_key: PK,
) -> Result<Response<hyper::body::Incoming>, Error> {
let uri = req.uri().clone();

loop {
req = match self.try_send_request(req, pool_key.clone()).await {
let pk: PK = pool_key.clone();
req = match self.try_send_request(req, pk).await {
Ok(resp) => return Ok(resp),
Err(TrySendError::Nope(err)) => return Err(err),
Err(TrySendError::Retryable {
Expand Down Expand Up @@ -275,10 +267,11 @@ where
async fn try_send_request(
&self,
mut req: Request<B>,
pool_key: PoolKey,
pool_key: PK,
) -> Result<Response<hyper::body::Incoming>, TrySendError<B>> {
let uri = req.uri().clone();
let mut pooled = self
.connection_for(pool_key)
.connection_for(uri, pool_key)
.await
// `connection_for` already retries checkout errors, so if
// it returns an error, there's not much else to retry
Expand Down Expand Up @@ -381,10 +374,12 @@ where

async fn connection_for(
&self,
pool_key: PoolKey,
) -> Result<pool::Pooled<PoolClient<B>, PoolKey>, Error> {
uri: Uri,
pool_key: PK,
) -> Result<pool::Pooled<PoolClient<B>, PK>, Error> {
loop {
match self.one_connection_for(pool_key.clone()).await {
let pk: PK = pool_key.clone();
match self.one_connection_for(uri.clone(), pk).await {
Ok(pooled) => return Ok(pooled),
Err(ClientConnectError::Normal(err)) => return Err(err),
Err(ClientConnectError::CheckoutIsClosed(reason)) => {
Expand All @@ -404,12 +399,13 @@ where

async fn one_connection_for(
&self,
pool_key: PoolKey,
) -> Result<pool::Pooled<PoolClient<B>, PoolKey>, ClientConnectError> {
uri: Uri,
pool_key: PK,
) -> Result<pool::Pooled<PoolClient<B>, PK>, ClientConnectError> {
// Return a single connection if pooling is not enabled
if !self.pool.is_enabled() {
return self
.connect_to(pool_key)
.connect_to(uri, pool_key)
.await
.map_err(ClientConnectError::Normal);
}
Expand All @@ -428,7 +424,7 @@ where
// connection future is spawned into the runtime to complete,
// and then be inserted into the pool as an idle connection.
let checkout = self.pool.checkout(pool_key.clone());
let connect = self.connect_to(pool_key);
let connect = self.connect_to(uri, pool_key);
let is_ver_h2 = self.config.ver == Ver::Http2;

// The order of the `select` is depended on below...
Expand Down Expand Up @@ -497,9 +493,9 @@ where
#[cfg(any(feature = "http1", feature = "http2"))]
fn connect_to(
&self,
pool_key: PoolKey,
) -> impl Lazy<Output = Result<pool::Pooled<PoolClient<B>, PoolKey>, Error>> + Send + Unpin
{
dst: Uri,
pool_key: PK,
) -> impl Lazy<Output = Result<pool::Pooled<PoolClient<B>, PK>, Error>> + Send + Unpin {
let executor = self.exec.clone();
let pool = self.pool.clone();
#[cfg(feature = "http1")]
Expand All @@ -509,7 +505,6 @@ where
let ver = self.config.ver;
let is_ver_h2 = ver == Ver::Http2;
let connector = self.connector.clone();
let dst = domain_as_uri(pool_key.clone());
hyper_lazy(move || {
// Try to take a "connecting lock".
//
Expand Down Expand Up @@ -720,8 +715,8 @@ where
}
}

impl<C: Clone, B> Clone for Client<C, B> {
fn clone(&self) -> Client<C, B> {
impl<C: Clone, B, PK: pool::Key> Clone for Client<C, B, PK> {
fn clone(&self) -> Client<C, B, PK> {
Client {
config: self.config,
exec: self.exec.clone(),
Expand All @@ -730,6 +725,7 @@ impl<C: Clone, B> Clone for Client<C, B> {
#[cfg(feature = "http2")]
h2_builder: self.h2_builder.clone(),
connector: self.connector.clone(),
pool_key: self.pool_key.clone(),
pool: self.pool.clone(),
}
}
Expand All @@ -752,11 +748,6 @@ impl ResponseFuture {
inner: SyncWrapper::new(Box::pin(value)),
}
}

fn error_version(ver: Version) -> Self {
warn!("Request has unsupported version \"{:?}\"", ver);
ResponseFuture::new(Box::pin(future::err(e!(UserUnsupportedVersion))))
}
}

impl fmt::Debug for ResponseFuture {
Expand Down Expand Up @@ -950,7 +941,28 @@ fn authority_form(uri: &mut Uri) {
};
}

fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> Result<PoolKey, Error> {
fn default_pool_key(req: &mut http::request::Parts) -> Result<DefaultPoolKey, Error> {
let is_http_connect = req.method == Method::CONNECT;
match req.version {
Version::HTTP_11 => (),
Version::HTTP_10 => {
if is_http_connect {
warn!("CONNECT is not allowed for HTTP/1.0");
return Err(e!(UserUnsupportedRequestMethod));
}
}
Version::HTTP_2 => (),
// completely unsupported HTTP version (like HTTP/0.9)!
other => {
warn!("Request has unsupported version \"{:?}\"", other);
return Err(e!(UserUnsupportedVersion));
}
};

extract_domain(&mut req.uri, is_http_connect)
}

fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> Result<DefaultPoolKey, Error> {
let uri_clone = uri.clone();
match (uri_clone.scheme(), uri_clone.authority()) {
(Some(scheme), Some(auth)) => Ok((scheme.clone(), auth.clone())),
Expand All @@ -974,15 +986,6 @@ fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> Result<PoolKey, Error
}
}

fn domain_as_uri((scheme, auth): PoolKey) -> Uri {
http::uri::Builder::new()
.scheme(scheme)
.authority(auth)
.path_and_query("/")
.build()
.expect("domain is valid Uri")
}

fn set_scheme(uri: &mut Uri, scheme: Scheme) {
debug_assert!(
uri.scheme().is_none(),
Expand Down Expand Up @@ -1602,11 +1605,27 @@ impl Builder {
}

/// Combine the configuration of this builder with a connector to create a `Client`.
pub fn build<C, B>(&self, connector: C) -> Client<C, B>
pub fn build<'a, C, B>(&'a self, connector: C) -> Client<C, B, DefaultPoolKey>
where
C: Connect + Clone,
B: Body + Send,
B::Data: Send,
{
self.build_with_pool_key::<C, B, DefaultPoolKey>(connector, default_pool_key)
}

/// Combine the configuration of this builder with a connector to create a `Client`, with a custom pooling key.
/// A function to extract the pool key from the request is required.
pub fn build_with_pool_key<C, B, PK>(
&self,
connector: C,
pool_key: impl Fn(&mut http::request::Parts) -> Result<PK, Error> + Send + Sync + 'static,
) -> Client<C, B, PK>
where
C: Connect + Clone,
B: Body + Send,
B::Data: Send,
PK: pool::Key,
{
let exec = self.exec.clone();
let timer = self.pool_timer.clone();
Expand All @@ -1618,7 +1637,8 @@ impl Builder {
#[cfg(feature = "http2")]
h2_builder: self.h2_builder.clone(),
connector,
pool: pool::Pool::new(self.pool_config, exec, timer),
pool_key: Arc::new(pool_key),
pool: pool::Pool::<_, PK>::new(self.pool_config, exec, timer),
}
}
}
Expand Down