Skip to content

Commit 9f31e38

Browse files
committed
client: support custom pool keys
This PR introduces the ability to use a custom pool key with the legacy Client. #16 added generic support to `Pool` itself long ago, but it was never fully finished to allow a custom key in `Client`. To support backwards compatibility, a default PoolKey is assigned to the generic parameter which seems reasonably backwards compatible(?) When providing a custom pool key, the user is required to also pass a constructor that generates the pool key from the `http::request::Parts`. I had also considered making a user pass in the PoolKey as part of the `request()` call, but I think this is a worse experience generally and harder to be backwards compatible. If a user did want to do a per-request key they can always set an extension in `Parts` so this approach is equally flexible.
1 parent d9107d0 commit 9f31e38

File tree

1 file changed

+70
-53
lines changed

1 file changed

+70
-53
lines changed

src/client/legacy/client.rs

Lines changed: 70 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use std::error::Error as StdError;
88
use std::fmt;
99
use std::future::Future;
1010
use std::pin::Pin;
11+
use std::sync::Arc;
1112
use std::task::{self, Poll};
1213
use std::time::Duration;
1314

@@ -35,15 +36,16 @@ type BoxSendFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
3536
/// `Client` is cheap to clone and cloning is the recommended way to share a `Client`. The
3637
/// underlying connection pool will be reused.
3738
#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
38-
pub struct Client<C, B> {
39+
pub struct Client<C, B, PK: pool::Key = DefaultPoolKey> {
3940
config: Config,
4041
connector: C,
4142
exec: Exec,
4243
#[cfg(feature = "http1")]
4344
h1_builder: hyper::client::conn::http1::Builder,
4445
#[cfg(feature = "http2")]
4546
h2_builder: hyper::client::conn::http2::Builder<Exec>,
46-
pool: pool::Pool<PoolClient<B>, PoolKey>,
47+
pool_key: Arc<dyn Fn(&mut http::request::Parts) -> Result<PK, Error> + Send + Sync + 'static>,
48+
pool: pool::Pool<PoolClient<B>, PK>,
4749
}
4850

4951
#[derive(Clone, Copy, Debug)]
@@ -90,7 +92,7 @@ macro_rules! e {
9092
}
9193

9294
// We might change this... :shrug:
93-
type PoolKey = (http::uri::Scheme, http::uri::Authority);
95+
type DefaultPoolKey = (http::uri::Scheme, http::uri::Authority);
9496

9597
enum TrySendError<B> {
9698
Retryable {
@@ -143,12 +145,13 @@ impl Client<(), ()> {
143145
}
144146
}
145147

146-
impl<C, B> Client<C, B>
148+
impl<C, B, PK> Client<C, B, PK>
147149
where
148150
C: Connect + Clone + Send + Sync + 'static,
149151
B: Body + Send + 'static + Unpin,
150152
B::Data: Send,
151153
B::Error: Into<Box<dyn StdError + Send + Sync>>,
154+
PK: pool::Key,
152155
{
153156
/// Send a `GET` request to the supplied `Uri`.
154157
///
@@ -214,35 +217,23 @@ where
214217
/// # }
215218
/// # fn main() {}
216219
/// ```
217-
pub fn request(&self, mut req: Request<B>) -> ResponseFuture {
218-
let is_http_connect = req.method() == Method::CONNECT;
219-
match req.version() {
220-
Version::HTTP_11 => (),
221-
Version::HTTP_10 => {
222-
if is_http_connect {
223-
warn!("CONNECT is not allowed for HTTP/1.0");
224-
return ResponseFuture::new(future::err(e!(UserUnsupportedRequestMethod)));
225-
}
226-
}
227-
Version::HTTP_2 => (),
228-
// completely unsupported HTTP version (like HTTP/0.9)!
229-
other => return ResponseFuture::error_version(other),
230-
};
231-
232-
let pool_key = match extract_domain(req.uri_mut(), is_http_connect) {
220+
pub fn request(&self, req: Request<B>) -> ResponseFuture {
221+
let (mut parts, body) = req.into_parts();
222+
let pool_key = match (self.pool_key)(&mut parts) {
233223
Ok(s) => s,
234224
Err(err) => {
235225
return ResponseFuture::new(future::err(err));
236226
}
237227
};
228+
let req = Request::from_parts(parts, body);
238229

239230
ResponseFuture::new(self.clone().send_request(req, pool_key))
240231
}
241232

242233
async fn send_request(
243234
self,
244235
mut req: Request<B>,
245-
pool_key: PoolKey,
236+
pool_key: PK,
246237
) -> Result<Response<hyper::body::Incoming>, Error> {
247238
let uri = req.uri().clone();
248239

@@ -275,10 +266,10 @@ where
275266
async fn try_send_request(
276267
&self,
277268
mut req: Request<B>,
278-
pool_key: PoolKey,
269+
pool_key: PK,
279270
) -> Result<Response<hyper::body::Incoming>, TrySendError<B>> {
280271
let mut pooled = self
281-
.connection_for(pool_key)
272+
.connection_for(req.uri().clone(), pool_key)
282273
.await
283274
// `connection_for` already retries checkout errors, so if
284275
// it returns an error, there's not much else to retry
@@ -381,10 +372,11 @@ where
381372

382373
async fn connection_for(
383374
&self,
384-
pool_key: PoolKey,
385-
) -> Result<pool::Pooled<PoolClient<B>, PoolKey>, Error> {
375+
uri: Uri,
376+
pool_key: PK,
377+
) -> Result<pool::Pooled<PoolClient<B>, PK>, Error> {
386378
loop {
387-
match self.one_connection_for(pool_key.clone()).await {
379+
match self.one_connection_for(uri.clone(), pool_key.clone()).await {
388380
Ok(pooled) => return Ok(pooled),
389381
Err(ClientConnectError::Normal(err)) => return Err(err),
390382
Err(ClientConnectError::CheckoutIsClosed(reason)) => {
@@ -404,12 +396,13 @@ where
404396

405397
async fn one_connection_for(
406398
&self,
407-
pool_key: PoolKey,
408-
) -> Result<pool::Pooled<PoolClient<B>, PoolKey>, ClientConnectError> {
399+
uri: Uri,
400+
pool_key: PK,
401+
) -> Result<pool::Pooled<PoolClient<B>, PK>, ClientConnectError> {
409402
// Return a single connection if pooling is not enabled
410403
if !self.pool.is_enabled() {
411404
return self
412-
.connect_to(pool_key)
405+
.connect_to(uri, pool_key)
413406
.await
414407
.map_err(ClientConnectError::Normal);
415408
}
@@ -428,7 +421,7 @@ where
428421
// connection future is spawned into the runtime to complete,
429422
// and then be inserted into the pool as an idle connection.
430423
let checkout = self.pool.checkout(pool_key.clone());
431-
let connect = self.connect_to(pool_key);
424+
let connect = self.connect_to(uri, pool_key);
432425
let is_ver_h2 = self.config.ver == Ver::Http2;
433426

434427
// The order of the `select` is depended on below...
@@ -497,9 +490,9 @@ where
497490
#[cfg(any(feature = "http1", feature = "http2"))]
498491
fn connect_to(
499492
&self,
500-
pool_key: PoolKey,
501-
) -> impl Lazy<Output = Result<pool::Pooled<PoolClient<B>, PoolKey>, Error>> + Send + Unpin
502-
{
493+
dst: Uri,
494+
pool_key: PK,
495+
) -> impl Lazy<Output = Result<pool::Pooled<PoolClient<B>, PK>, Error>> + Send + Unpin {
503496
let executor = self.exec.clone();
504497
let pool = self.pool.clone();
505498
#[cfg(feature = "http1")]
@@ -509,7 +502,6 @@ where
509502
let ver = self.config.ver;
510503
let is_ver_h2 = ver == Ver::Http2;
511504
let connector = self.connector.clone();
512-
let dst = domain_as_uri(pool_key.clone());
513505
hyper_lazy(move || {
514506
// Try to take a "connecting lock".
515507
//
@@ -720,8 +712,8 @@ where
720712
}
721713
}
722714

723-
impl<C: Clone, B> Clone for Client<C, B> {
724-
fn clone(&self) -> Client<C, B> {
715+
impl<C: Clone, B, PK: pool::Key> Clone for Client<C, B, PK> {
716+
fn clone(&self) -> Client<C, B, PK> {
725717
Client {
726718
config: self.config,
727719
exec: self.exec.clone(),
@@ -730,6 +722,7 @@ impl<C: Clone, B> Clone for Client<C, B> {
730722
#[cfg(feature = "http2")]
731723
h2_builder: self.h2_builder.clone(),
732724
connector: self.connector.clone(),
725+
pool_key: self.pool_key.clone(),
733726
pool: self.pool.clone(),
734727
}
735728
}
@@ -752,11 +745,6 @@ impl ResponseFuture {
752745
inner: SyncWrapper::new(Box::pin(value)),
753746
}
754747
}
755-
756-
fn error_version(ver: Version) -> Self {
757-
warn!("Request has unsupported version \"{:?}\"", ver);
758-
ResponseFuture::new(Box::pin(future::err(e!(UserUnsupportedVersion))))
759-
}
760748
}
761749

762750
impl fmt::Debug for ResponseFuture {
@@ -950,7 +938,28 @@ fn authority_form(uri: &mut Uri) {
950938
};
951939
}
952940

953-
fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> Result<PoolKey, Error> {
941+
fn default_pool_key(req: &mut http::request::Parts) -> Result<DefaultPoolKey, Error> {
942+
let is_http_connect = req.method == Method::CONNECT;
943+
match req.version {
944+
Version::HTTP_11 => (),
945+
Version::HTTP_10 => {
946+
if is_http_connect {
947+
warn!("CONNECT is not allowed for HTTP/1.0");
948+
return Err(e!(UserUnsupportedRequestMethod));
949+
}
950+
}
951+
Version::HTTP_2 => (),
952+
// completely unsupported HTTP version (like HTTP/0.9)!
953+
other => {
954+
warn!("Request has unsupported version \"{:?}\"", other);
955+
return Err(e!(UserUnsupportedVersion));
956+
}
957+
};
958+
959+
extract_domain(&mut req.uri, is_http_connect)
960+
}
961+
962+
fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> Result<DefaultPoolKey, Error> {
954963
let uri_clone = uri.clone();
955964
match (uri_clone.scheme(), uri_clone.authority()) {
956965
(Some(scheme), Some(auth)) => Ok((scheme.clone(), auth.clone())),
@@ -974,15 +983,6 @@ fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> Result<PoolKey, Error
974983
}
975984
}
976985

977-
fn domain_as_uri((scheme, auth): PoolKey) -> Uri {
978-
http::uri::Builder::new()
979-
.scheme(scheme)
980-
.authority(auth)
981-
.path_and_query("/")
982-
.build()
983-
.expect("domain is valid Uri")
984-
}
985-
986986
fn set_scheme(uri: &mut Uri, scheme: Scheme) {
987987
debug_assert!(
988988
uri.scheme().is_none(),
@@ -1602,11 +1602,27 @@ impl Builder {
16021602
}
16031603

16041604
/// Combine the configuration of this builder with a connector to create a `Client`.
1605-
pub fn build<C, B>(&self, connector: C) -> Client<C, B>
1605+
pub fn build<'a, C, B>(&'a self, connector: C) -> Client<C, B, DefaultPoolKey>
1606+
where
1607+
C: Connect + Clone,
1608+
B: Body + Send,
1609+
B::Data: Send,
1610+
{
1611+
self.build_with_pool_key::<C, B, DefaultPoolKey>(connector, default_pool_key)
1612+
}
1613+
1614+
/// Combine the configuration of this builder with a connector to create a `Client`, with a custom pooling key.
1615+
/// A function to extract the pool key from the request is required.
1616+
pub fn build_with_pool_key<C, B, PK>(
1617+
&self,
1618+
connector: C,
1619+
pool_key: impl Fn(&mut http::request::Parts) -> Result<PK, Error> + Send + Sync + 'static,
1620+
) -> Client<C, B, PK>
16061621
where
16071622
C: Connect + Clone,
16081623
B: Body + Send,
16091624
B::Data: Send,
1625+
PK: pool::Key,
16101626
{
16111627
let exec = self.exec.clone();
16121628
let timer = self.pool_timer.clone();
@@ -1618,7 +1634,8 @@ impl Builder {
16181634
#[cfg(feature = "http2")]
16191635
h2_builder: self.h2_builder.clone(),
16201636
connector,
1621-
pool: pool::Pool::new(self.pool_config, exec, timer),
1637+
pool_key: Arc::new(pool_key),
1638+
pool: pool::Pool::<_, PK>::new(self.pool_config, exec, timer),
16221639
}
16231640
}
16241641
}

0 commit comments

Comments
 (0)