diff --git a/src/simd/swar.rs b/src/simd/swar.rs index 13f58a8..c950c66 100644 --- a/src/simd/swar.rs +++ b/src/simd/swar.rs @@ -1,17 +1,20 @@ /// SWAR: SIMD Within A Register /// SIMD validator backend that validates register-sized chunks of data at a time. -// TODO: current impl assumes 64-bit registers, optimize for 32-bit use crate::{is_header_name_token, is_header_value_token, is_uri_token, Bytes}; +// Adapt block-size to match native register size, i.e: 32bit => 4, 64bit => 8 +const BLOCK_SIZE: usize = core::mem::size_of::(); +type ByteBlock = [u8; BLOCK_SIZE]; + #[inline] pub fn match_uri_vectored(bytes: &mut Bytes) { loop { - if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) { + if let Some(bytes8) = bytes.peek_n::(BLOCK_SIZE) { let n = match_uri_char_8_swar(bytes8); unsafe { bytes.advance(n); } - if n == 8 { + if n == BLOCK_SIZE { continue; } } @@ -28,12 +31,12 @@ pub fn match_uri_vectored(bytes: &mut Bytes) { #[inline] pub fn match_header_value_vectored(bytes: &mut Bytes) { loop { - if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) { + if let Some(bytes8) = bytes.peek_n::(BLOCK_SIZE) { let n = match_header_value_char_8_swar(bytes8); unsafe { bytes.advance(n); } - if n == 8 { + if n == BLOCK_SIZE { continue; } } @@ -49,19 +52,19 @@ pub fn match_header_value_vectored(bytes: &mut Bytes) { #[inline] pub fn match_header_name_vectored(bytes: &mut Bytes) { - while let Some(block) = bytes.peek_n::<[u8; 8]>(8) { + while let Some(block) = bytes.peek_n::(BLOCK_SIZE) { let n = match_block(is_header_name_token, block); unsafe { bytes.advance(n); } - if n != 8 { + if n != BLOCK_SIZE { return; } } unsafe { bytes.advance(match_tail(is_header_name_token, bytes.as_ref())) }; } -// Matches "tail", i.e: when we have <8 bytes in the buffer, should be uncommon +// Matches "tail", i.e: when we have bool, bytes: &[u8]) -> usize { @@ -75,19 +78,19 @@ fn match_tail(f: impl Fn(u8) -> bool, bytes: &[u8]) -> usize { // Naive fallback block matcher #[inline(always)] -fn match_block(f: impl Fn(u8) -> bool, block: [u8; 8]) -> usize { +fn match_block(f: impl Fn(u8) -> bool, block: ByteBlock) -> usize { for (i, &b) in block.iter().enumerate() { if !f(b) { return i; } } - 8 + BLOCK_SIZE } -/// // A const alternative to u64::from_ne_bytes to avoid bumping MSRV (1.36 => 1.44) +// A const alternative to u64::from_ne_bytes to avoid bumping MSRV (1.36 => 1.44) // creates a u64 whose bytes are each equal to b -const fn uniform_block(b: u8) -> u64 { - b as u64 * 0x01_01_01_01_01_01_01_01 // [1_u8; 8] +const fn uniform_block(b: u8) -> usize { + (b as u64 * 0x01_01_01_01_01_01_01_01 /* [1_u8; 8] */) as usize } // A byte-wise range-check on an enire word/block, @@ -95,15 +98,15 @@ const fn uniform_block(b: u8) -> u64 { // `33 <= x <= 126 && x != '>' && x != '<'` // IMPORTANT: it false negatives if the block contains '?' #[inline] -fn match_uri_char_8_swar(block: [u8; 8]) -> usize { +fn match_uri_char_8_swar(block: ByteBlock) -> usize { // 33 <= x <= 126 const M: u8 = 0x21; const N: u8 = 0x7E; - const BM: u64 = uniform_block(M); - const BN: u64 = uniform_block(127 - N); - const M128: u64 = uniform_block(128); + const BM: usize = uniform_block(M); + const BN: usize = uniform_block(127 - N); + const M128: usize = uniform_block(128); - let x = u64::from_ne_bytes(block); // Really just a transmute + let x = usize::from_ne_bytes(block); // Really just a transmute let lt = x.wrapping_sub(BM) & !x; // <= m let gt = x.wrapping_add(BN) | x; // >= n @@ -130,8 +133,8 @@ fn match_uri_char_8_swar(block: [u8; 8]) -> usize { // } // (xordist(b'<', 2), xordist(b'>', 2)) // ``` - const B3: u64 = uniform_block(3); // (dist <= 2) + 1 to wrap - const BGT: u64 = uniform_block(b'>'); + const B3: usize = uniform_block(3); // (dist <= 2) + 1 to wrap + const BGT: usize = uniform_block(b'>'); let xgt = x ^ BGT; let ltgtq = xgt.wrapping_sub(B3) & !xgt; @@ -143,15 +146,15 @@ fn match_uri_char_8_swar(block: [u8; 8]) -> usize { // ensuring all bytes in the word satisfy `32 <= x <= 126` // IMPORTANT: false negatives if obs-text is present (0x80..=0xFF) #[inline] -fn match_header_value_char_8_swar(block: [u8; 8]) -> usize { +fn match_header_value_char_8_swar(block: ByteBlock) -> usize { // 32 <= x <= 126 const M: u8 = 0x20; const N: u8 = 0x7E; - const BM: u64 = uniform_block(M); - const BN: u64 = uniform_block(127 - N); - const M128: u64 = uniform_block(128); + const BM: usize = uniform_block(M); + const BN: usize = uniform_block(127 - N); + const M128: usize = uniform_block(128); - let x = u64::from_ne_bytes(block); // Really just a transmute + let x = usize::from_ne_bytes(block); // Really just a transmute let lt = x.wrapping_sub(BM) & !x; // <= m let gt = x.wrapping_add(BN) | x; // >= n offsetnz((lt | gt) & M128) @@ -160,10 +163,10 @@ fn match_header_value_char_8_swar(block: [u8; 8]) -> usize { /// Check block to find offset of first non-zero byte // NOTE: Curiously `block.trailing_zeros() >> 3` appears to be slower, maybe revisit #[inline] -fn offsetnz(block: u64) -> usize { +fn offsetnz(block: usize) -> usize { // fast path optimistic case (common for long valid sequences) if block == 0 { - return 8; + return BLOCK_SIZE; } // perf: rust will unroll this loop @@ -177,19 +180,19 @@ fn offsetnz(block: u64) -> usize { #[test] fn test_is_header_value_block() { - let is_header_value_block = |b| match_header_value_char_8_swar(b) == 8; + let is_header_value_block = |b| match_header_value_char_8_swar(b) == BLOCK_SIZE; // 0..32 => false for b in 0..32_u8 { - assert_eq!(is_header_value_block([b; 8]), false, "b={}", b); + assert_eq!(is_header_value_block([b; BLOCK_SIZE]), false, "b={}", b); } // 32..127 => true for b in 32..127_u8 { - assert_eq!(is_header_value_block([b; 8]), true, "b={}", b); + assert_eq!(is_header_value_block([b; BLOCK_SIZE]), true, "b={}", b); } // 127..=255 => false for b in 127..=255_u8 { - assert_eq!(is_header_value_block([b; 8]), false, "b={}", b); + assert_eq!(is_header_value_block([b; BLOCK_SIZE]), false, "b={}", b); } // A few sanity checks on non-uniform bytes for safe-measure @@ -199,30 +202,30 @@ fn test_is_header_value_block() { #[test] fn test_is_uri_block() { - let is_uri_block = |b| match_uri_char_8_swar(b) == 8; + let is_uri_block = |b| match_uri_char_8_swar(b) == BLOCK_SIZE; // 0..33 => false for b in 0..33_u8 { - assert_eq!(is_uri_block([b; 8]), false, "b={}", b); + assert_eq!(is_uri_block([b; BLOCK_SIZE]), false, "b={}", b); } // 33..127 => true if b not in { '<', '?', '>' } let falsy = |b| b"".contains(&b); for b in 33..127_u8 { - assert_eq!(is_uri_block([b; 8]), !falsy(b), "b={}", b); + assert_eq!(is_uri_block([b; BLOCK_SIZE]), !falsy(b), "b={}", b); } // 127..=255 => false for b in 127..=255_u8 { - assert_eq!(is_uri_block([b; 8]), false, "b={}", b); + assert_eq!(is_uri_block([b; BLOCK_SIZE]), false, "b={}", b); } } #[test] fn test_offsetnz() { - let seq = [0_u8; 8]; - for i in 0..8 { + let seq = [0_u8; BLOCK_SIZE]; + for i in 0..BLOCK_SIZE { let mut seq = seq.clone(); seq[i] = 1; - let x = u64::from_ne_bytes(seq); + let x = usize::from_ne_bytes(seq); assert_eq!(offsetnz(x), i); } }