Skip to content

Commit 96fab05

Browse files
authored
io: wrappers for inspecting data on IO resources (#5033)
1 parent 9b87daa commit 96fab05

File tree

3 files changed

+331
-0
lines changed

3 files changed

+331
-0
lines changed

tokio-util/src/io/inspect.rs

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
use futures_core::ready;
2+
use pin_project_lite::pin_project;
3+
use std::io::{IoSlice, Result};
4+
use std::pin::Pin;
5+
use std::task::{Context, Poll};
6+
7+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8+
9+
pin_project! {
10+
/// An adapter that lets you inspect the data that's being read.
11+
///
12+
/// This is useful for things like hashing data as it's read in.
13+
pub struct InspectReader<R, F> {
14+
#[pin]
15+
reader: R,
16+
f: F,
17+
}
18+
}
19+
20+
impl<R, F> InspectReader<R, F> {
21+
/// Create a new InspectReader, wrapping `reader` and calling `f` for the
22+
/// new data supplied by each read call.
23+
///
24+
/// The closure will only be called with an empty slice if the inner reader
25+
/// returns without reading data into the buffer. This happens at EOF, or if
26+
/// `poll_read` is called with a zero-size buffer.
27+
pub fn new(reader: R, f: F) -> InspectReader<R, F>
28+
where
29+
R: AsyncRead,
30+
F: FnMut(&[u8]),
31+
{
32+
InspectReader { reader, f }
33+
}
34+
35+
/// Consumes the `InspectReader`, returning the wrapped reader
36+
pub fn into_inner(self) -> R {
37+
self.reader
38+
}
39+
}
40+
41+
impl<R: AsyncRead, F: FnMut(&[u8])> AsyncRead for InspectReader<R, F> {
42+
fn poll_read(
43+
self: Pin<&mut Self>,
44+
cx: &mut Context<'_>,
45+
buf: &mut ReadBuf<'_>,
46+
) -> Poll<Result<()>> {
47+
let me = self.project();
48+
let filled_length = buf.filled().len();
49+
ready!(me.reader.poll_read(cx, buf))?;
50+
(me.f)(&buf.filled()[filled_length..]);
51+
Poll::Ready(Ok(()))
52+
}
53+
}
54+
55+
pin_project! {
56+
/// An adapter that lets you inspect the data that's being written.
57+
///
58+
/// This is useful for things like hashing data as it's written out.
59+
pub struct InspectWriter<W, F> {
60+
#[pin]
61+
writer: W,
62+
f: F,
63+
}
64+
}
65+
66+
impl<W, F> InspectWriter<W, F> {
67+
/// Create a new InspectWriter, wrapping `write` and calling `f` for the
68+
/// data successfully written by each write call.
69+
///
70+
/// The closure `f` will never be called with an empty slice. A vectored
71+
/// write can result in multiple calls to `f` - at most one call to `f` per
72+
/// buffer supplied to `poll_write_vectored`.
73+
pub fn new(writer: W, f: F) -> InspectWriter<W, F>
74+
where
75+
W: AsyncWrite,
76+
F: FnMut(&[u8]),
77+
{
78+
InspectWriter { writer, f }
79+
}
80+
81+
/// Consumes the `InspectWriter`, returning the wrapped writer
82+
pub fn into_inner(self) -> W {
83+
self.writer
84+
}
85+
}
86+
87+
impl<W: AsyncWrite, F: FnMut(&[u8])> AsyncWrite for InspectWriter<W, F> {
88+
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
89+
let me = self.project();
90+
let res = me.writer.poll_write(cx, buf);
91+
if let Poll::Ready(Ok(count)) = res {
92+
if count != 0 {
93+
(me.f)(&buf[..count]);
94+
}
95+
}
96+
res
97+
}
98+
99+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
100+
let me = self.project();
101+
me.writer.poll_flush(cx)
102+
}
103+
104+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
105+
let me = self.project();
106+
me.writer.poll_shutdown(cx)
107+
}
108+
109+
fn poll_write_vectored(
110+
self: Pin<&mut Self>,
111+
cx: &mut Context<'_>,
112+
bufs: &[IoSlice<'_>],
113+
) -> Poll<Result<usize>> {
114+
let me = self.project();
115+
let res = me.writer.poll_write_vectored(cx, bufs);
116+
if let Poll::Ready(Ok(mut count)) = res {
117+
for buf in bufs {
118+
if count == 0 {
119+
break;
120+
}
121+
let size = count.min(buf.len());
122+
if size != 0 {
123+
(me.f)(&buf[..size]);
124+
count -= size;
125+
}
126+
}
127+
}
128+
res
129+
}
130+
131+
fn is_write_vectored(&self) -> bool {
132+
self.writer.is_write_vectored()
133+
}
134+
}

tokio-util/src/io/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,17 @@
1010
//! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html
1111
//! [`AsyncRead`]: tokio::io::AsyncRead
1212
13+
mod inspect;
1314
mod read_buf;
1415
mod reader_stream;
1516
mod stream_reader;
17+
1618
cfg_io_util! {
1719
mod sync_bridge;
1820
pub use self::sync_bridge::SyncIoBridge;
1921
}
2022

23+
pub use self::inspect::{InspectReader, InspectWriter};
2124
pub use self::read_buf::read_buf;
2225
pub use self::reader_stream::ReaderStream;
2326
pub use self::stream_reader::StreamReader;

tokio-util/tests/io_inspect.rs

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
use futures::future::poll_fn;
2+
use std::{
3+
io::IoSlice,
4+
pin::Pin,
5+
task::{Context, Poll},
6+
};
7+
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
8+
use tokio_util::io::{InspectReader, InspectWriter};
9+
10+
/// An AsyncRead implementation that works byte-by-byte, to catch out callers
11+
/// who don't allow for `buf` being part-filled before the call
12+
struct SmallReader {
13+
contents: Vec<u8>,
14+
}
15+
16+
impl Unpin for SmallReader {}
17+
18+
impl AsyncRead for SmallReader {
19+
fn poll_read(
20+
mut self: Pin<&mut Self>,
21+
_cx: &mut Context<'_>,
22+
buf: &mut ReadBuf<'_>,
23+
) -> Poll<std::io::Result<()>> {
24+
if let Some(byte) = self.contents.pop() {
25+
buf.put_slice(&[byte])
26+
}
27+
Poll::Ready(Ok(()))
28+
}
29+
}
30+
31+
#[tokio::test]
32+
async fn read_tee() {
33+
let contents = b"This could be really long, you know".to_vec();
34+
let reader = SmallReader {
35+
contents: contents.clone(),
36+
};
37+
let mut altout: Vec<u8> = Vec::new();
38+
let mut teeout = Vec::new();
39+
{
40+
let mut tee = InspectReader::new(reader, |bytes| altout.extend(bytes));
41+
tee.read_to_end(&mut teeout).await.unwrap();
42+
}
43+
assert_eq!(teeout, altout);
44+
assert_eq!(altout.len(), contents.len());
45+
}
46+
47+
/// An AsyncWrite implementation that works byte-by-byte for poll_write, and
48+
/// that reads the whole of the first buffer plus one byte from the second in
49+
/// poll_write_vectored.
50+
///
51+
/// This is designed to catch bugs in handling partially written buffers
52+
#[derive(Debug)]
53+
struct SmallWriter {
54+
contents: Vec<u8>,
55+
}
56+
57+
impl Unpin for SmallWriter {}
58+
59+
impl AsyncWrite for SmallWriter {
60+
fn poll_write(
61+
mut self: Pin<&mut Self>,
62+
_cx: &mut Context<'_>,
63+
buf: &[u8],
64+
) -> Poll<Result<usize, std::io::Error>> {
65+
// Just write one byte at a time
66+
if buf.is_empty() {
67+
return Poll::Ready(Ok(0));
68+
}
69+
self.contents.push(buf[0]);
70+
Poll::Ready(Ok(1))
71+
}
72+
73+
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
74+
Poll::Ready(Ok(()))
75+
}
76+
77+
fn poll_shutdown(
78+
self: Pin<&mut Self>,
79+
_cx: &mut Context<'_>,
80+
) -> Poll<Result<(), std::io::Error>> {
81+
Poll::Ready(Ok(()))
82+
}
83+
84+
fn poll_write_vectored(
85+
mut self: Pin<&mut Self>,
86+
_cx: &mut Context<'_>,
87+
bufs: &[IoSlice<'_>],
88+
) -> Poll<Result<usize, std::io::Error>> {
89+
// Write all of the first buffer, then one byte from the second buffer
90+
// This should trip up anything that doesn't correctly handle multiple
91+
// buffers.
92+
if bufs.is_empty() {
93+
return Poll::Ready(Ok(0));
94+
}
95+
let mut written_len = bufs[0].len();
96+
self.contents.extend_from_slice(&bufs[0]);
97+
98+
if bufs.len() > 1 {
99+
let buf = bufs[1];
100+
if !buf.is_empty() {
101+
written_len += 1;
102+
self.contents.push(buf[0]);
103+
}
104+
}
105+
Poll::Ready(Ok(written_len))
106+
}
107+
108+
fn is_write_vectored(&self) -> bool {
109+
true
110+
}
111+
}
112+
113+
#[tokio::test]
114+
async fn write_tee() {
115+
let mut altout: Vec<u8> = Vec::new();
116+
let mut writeout = SmallWriter {
117+
contents: Vec::new(),
118+
};
119+
{
120+
let mut tee = InspectWriter::new(&mut writeout, |bytes| altout.extend(bytes));
121+
tee.write_all(b"A testing string, very testing")
122+
.await
123+
.unwrap();
124+
}
125+
assert_eq!(altout, writeout.contents);
126+
}
127+
128+
// This is inefficient, but works well enough for test use.
129+
// If you want something similar for real code, you'll want to avoid all the
130+
// fun of manipulating `bufs` - ideally, by the time you read this,
131+
// IoSlice::advance_slices will be stable, and you can use that.
132+
async fn write_all_vectored<W: AsyncWrite + Unpin>(
133+
mut writer: W,
134+
mut bufs: Vec<Vec<u8>>,
135+
) -> Result<usize, std::io::Error> {
136+
let mut res = 0;
137+
while !bufs.is_empty() {
138+
let mut written = poll_fn(|cx| {
139+
let bufs: Vec<IoSlice> = bufs.iter().map(|v| IoSlice::new(v)).collect();
140+
Pin::new(&mut writer).poll_write_vectored(cx, &bufs)
141+
})
142+
.await?;
143+
res += written;
144+
while written > 0 {
145+
let buf_len = bufs[0].len();
146+
if buf_len <= written {
147+
bufs.remove(0);
148+
written -= buf_len;
149+
} else {
150+
let buf = &mut bufs[0];
151+
let drain_len = written.min(buf.len());
152+
buf.drain(..drain_len);
153+
written -= drain_len;
154+
}
155+
}
156+
}
157+
Ok(res)
158+
}
159+
160+
#[tokio::test]
161+
async fn write_tee_vectored() {
162+
let mut altout: Vec<u8> = Vec::new();
163+
let mut writeout = SmallWriter {
164+
contents: Vec::new(),
165+
};
166+
let original = b"A very long string split up";
167+
let bufs: Vec<Vec<u8>> = original
168+
.split(|b| b.is_ascii_whitespace())
169+
.map(Vec::from)
170+
.collect();
171+
assert!(bufs.len() > 1);
172+
let expected: Vec<u8> = {
173+
let mut out = Vec::new();
174+
for item in &bufs {
175+
out.extend_from_slice(item)
176+
}
177+
out
178+
};
179+
{
180+
let mut bufcount = 0;
181+
let tee = InspectWriter::new(&mut writeout, |bytes| {
182+
bufcount += 1;
183+
altout.extend(bytes)
184+
});
185+
186+
assert!(tee.is_write_vectored());
187+
188+
write_all_vectored(tee, bufs.clone()).await.unwrap();
189+
190+
assert!(bufcount >= bufs.len());
191+
}
192+
assert_eq!(altout, writeout.contents);
193+
assert_eq!(writeout.contents, expected);
194+
}

0 commit comments

Comments
 (0)