diff --git a/src/lib.rs b/src/lib.rs index 3232852..2d27674 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -547,7 +547,7 @@ impl<'h, 'b> Request<'h, 'b> { newline!(bytes); let len = orig_len - bytes.len(); - let headers_len = complete!(parse_headers_iter_uninit( + let parse_headers_status = parse_headers_iter_uninit( &mut headers, &mut bytes, &HeaderParserConfig { @@ -556,11 +556,14 @@ impl<'h, 'b> Request<'h, 'b> { allow_space_before_first_header_name: config.allow_space_before_first_header_name, ignore_invalid_headers: config.ignore_invalid_headers_in_requests }, - )); + )?; /* SAFETY: see `parse_headers_iter_uninit` guarantees */ self.headers = unsafe { assume_init_slice(headers) }; - Ok(Status::Complete(len + headers_len)) + match parse_headers_status { + Status::Complete(headers_len) => Ok(Status::Complete(len + headers_len)), + Status::Partial => Ok(Status::Partial), + } } /// Try to parse a buffer of bytes into the Request, @@ -583,7 +586,7 @@ impl<'h, 'b> Request<'h, 'b> { let headers: *mut [Header<'_>] = headers; let headers = headers as *mut [MaybeUninit>]; match self.parse_with_config_and_uninit_headers(buf, config, &mut *headers) { - Ok(Status::Complete(idx)) => Ok(Status::Complete(idx)), + Ok(status) => Ok(status), other => { // put the original headers back self.headers = &mut *(headers as *mut [Header<'_>]); @@ -687,7 +690,7 @@ impl<'h, 'b> Response<'h, 'b> { let headers: *mut [Header<'_>] = headers; let headers = headers as *mut [MaybeUninit>]; match self.parse_with_config_and_uninit_headers(buf, config, &mut *headers) { - Ok(Status::Complete(idx)) => Ok(Status::Complete(idx)), + Ok(status) => Ok(status), other => { // put the original headers back self.headers = &mut *(headers as *mut [Header<'_>]); @@ -745,7 +748,7 @@ impl<'h, 'b> Response<'h, 'b> { let len = orig_len - bytes.len(); - let headers_len = complete!(parse_headers_iter_uninit( + let parse_headers_status = parse_headers_iter_uninit( &mut headers, &mut bytes, &HeaderParserConfig { @@ -754,10 +757,14 @@ impl<'h, 'b> Response<'h, 'b> { allow_space_before_first_header_name: config.allow_space_before_first_header_name, ignore_invalid_headers: config.ignore_invalid_headers_in_responses } - )); + )?; /* SAFETY: see `parse_headers_iter_uninit` guarantees */ self.headers = unsafe { assume_init_slice(headers) }; - Ok(Status::Complete(len + headers_len)) + + match parse_headers_status { + Status::Complete(headers_len) => Ok(Status::Complete(len + headers_len)), + Status::Partial => Ok(Status::Partial), + } } } @@ -1386,6 +1393,7 @@ pub fn parse_chunk_size(buf: &[u8]) #[cfg(test)] mod tests { + use core::mem::MaybeUninit; use super::{Request, Response, Status, EMPTY_HEADER, parse_chunk_size}; const NUM_OF_HEADERS: usize = 4; @@ -1559,7 +1567,7 @@ mod tests { assert_eq!(req.method.unwrap(), "GET"); assert_eq!(req.path.unwrap(), "/"); assert_eq!(req.version.unwrap(), 1); - assert_eq!(req.headers.len(), NUM_OF_HEADERS); // doesn't slice since not Complete + assert_eq!(req.headers.len(), 1); assert_eq!(req.headers[0].name, "Host"); assert_eq!(req.headers[0].value, b"yolo"); } @@ -1760,7 +1768,7 @@ mod tests { assert_eq!(res.version.unwrap(), 1); assert_eq!(res.code.unwrap(), 200); assert_eq!(res.reason.unwrap(), "OK"); - assert_eq!(res.headers.len(), NUM_OF_HEADERS); // doesn't slice since not Complete + assert_eq!(res.headers.len(), 1); assert_eq!(res.headers[0].name, "Server"); assert_eq!(res.headers[0].value, b"yolo"); } @@ -2591,4 +2599,52 @@ mod tests { assert_eq!(response.headers[0].name, "Space-Before-Header"); assert_eq!(response.headers[0].value, &b"hello there"[..]); } + + #[test] + fn test_request_partial_with_uninit_headers() { + const REQUEST: &[u8] = b"GET / HTTP/1.1\r\nFoo: bar\r\nBaz: quux\r\n"; + + let mut headers = unsafe { + MaybeUninit::<[MaybeUninit>; 4]>::uninit().assume_init() + }; + + let mut request = Request::new(&mut []); + + let result = crate::ParserConfig::default() + .parse_request_with_uninit_headers(&mut request, REQUEST, &mut headers); + + assert_eq!(result, Ok(Status::Partial)); + assert_eq!(request.method.unwrap(), "GET"); + assert_eq!(request.path.unwrap(), "/"); + assert_eq!(request.version.unwrap(), 1); + assert_eq!(request.headers.len(), 2); + assert_eq!(request.headers[0].name, "Foo"); + assert_eq!(request.headers[0].value, &b"bar"[..]); + assert_eq!(request.headers[1].name, "Baz"); + assert_eq!(request.headers[1].value, &b"quux"[..]); + } + + #[test] + fn test_response_partial_with_uninit_headers() { + const RESPONSE: &[u8] = b"HTTP/1.1 200 OK\r\nFoo: bar\r\nBaz: quux\r\n"; + + let mut headers = unsafe { + MaybeUninit::<[MaybeUninit>; 4]>::uninit().assume_init() + }; + + let mut response = Response::new(&mut []); + + let result = crate::ParserConfig::default() + .parse_response_with_uninit_headers(&mut response, RESPONSE, &mut headers); + + assert_eq!(result, Ok(Status::Partial)); + assert_eq!(response.version.unwrap(), 1); + assert_eq!(response.code.unwrap(), 200); + assert_eq!(response.reason.unwrap(), "OK"); + assert_eq!(response.headers.len(), 2); + assert_eq!(response.headers[0].name, "Foo"); + assert_eq!(response.headers[0].value, &b"bar"[..]); + assert_eq!(response.headers[1].name, "Baz"); + assert_eq!(response.headers[1].value, &b"quux"[..]); + } }