
Commit 5391346

perf: SIMD neon support
First pass at neon support, building off #132
1 parent fbb0bdd commit 5391346

2 files changed: +354 -0

src/simd/mod.rs

Lines changed: 14 additions & 0 deletions
@@ -3,6 +3,7 @@
     any(
         target_arch = "x86",
         target_arch = "x86_64",
+        target_arch = "aarch64",
     ),
 )))]
 mod fallback;
@@ -12,6 +13,7 @@ mod fallback;
     any(
         target_arch = "x86",
         target_arch = "x86_64",
+        target_arch = "aarch64",
     ),
 )))]
 pub use self::fallback::*;
@@ -129,3 +131,15 @@ mod avx2_compile_time {
     ),
 ))]
 pub use self::avx2_compile_time::*;
+
+#[cfg(all(
+    httparse_simd,
+    target_arch = "aarch64",
+))]
+mod neon;
+
+#[cfg(all(
+    httparse_simd,
+    target_arch = "aarch64",
+))]
+pub use self::neon::*;
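
The cfg gates above follow one pattern: the scalar fallback is compiled only when no SIMD backend applies, and every backend exports the same function names, so callers are oblivious to which implementation they get. A minimal sketch of that pattern (hypothetical module and function names, not the crate's actual layout):

#[cfg(target_arch = "aarch64")]
mod backend {
    pub fn match_uri_vectored() { /* NEON path */ }
}

#[cfg(not(target_arch = "aarch64"))]
mod backend {
    pub fn match_uri_vectored() { /* scalar fallback */ }
}

// Callers import one name; cfg selects which body it resolves to.
pub use self::backend::match_uri_vectored;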

src/simd/neon.rs

Lines changed: 340 additions & 0 deletions
@@ -0,0 +1,340 @@
use crate::iter::Bytes;
use core::arch::aarch64::*;

// NOTE: net-negative, so unused for now
#[allow(dead_code)]
#[inline]
pub fn match_header_name_vectored(bytes: &mut Bytes) {
    while bytes.as_ref().len() >= 16 {
        unsafe {
            let advance = match_header_name_char_16_neon(bytes.as_ref().as_ptr());
            bytes.advance(advance);

            if advance != 16 {
                break;
            }
        }
    }
}

#[inline]
pub fn match_header_value_vectored(bytes: &mut Bytes) {
    while bytes.as_ref().len() >= 16 {
        unsafe {
            let advance = match_header_value_char_16_neon(bytes.as_ref().as_ptr());
            bytes.advance(advance);

            if advance != 16 {
                break;
            }
        }
    }
}

#[inline]
pub fn match_uri_vectored(bytes: &mut Bytes) {
    while bytes.as_ref().len() >= 16 {
        unsafe {
            let advance = match_url_char_16_neon(bytes.as_ref().as_ptr());
            bytes.advance(advance);

            if advance != 16 {
                break;
            }
        }
    }
}
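
// The three drivers above share one pattern: each *_char_16_neon call
// reports how many of the next 16 bytes were accepted; the loop advances
// past fully-matching blocks and stops at the first block containing a
// rejected byte, leaving any tail shorter than 16 bytes to the scalar
// parser. Illustrative walk-through on hypothetical input:
//
//     let mut bytes = Bytes::new(b"/index.html HTTP/1.1\r\n");
//     match_uri_vectored(&mut bytes);
//     // The first 16-byte block is b"/index.html HTT"; the space at
//     // offset 11 is rejected, so advance == 11 and the loop stops
//     // with bytes.pos() == 11.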

const fn bit_set(x: u8) -> bool {
    // Validates if a byte is a valid header name character
    // https://tools.ietf.org/html/rfc7230#section-3.2.6
    matches!(x, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'!' | b'#' | b'$' | b'%' | b'&' | b'\'' | b'*' | b'+' | b'-' | b'.' | b'^' | b'_' | b'`' | b'|' | b'~')
}
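
// For example, bit_set(b'A') and bit_set(b'-') are true, while
// bit_set(b':') is false: the colon terminates a header name rather than
// being part of it.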

// A 256-bit bitmap, split into two halves
// lower half contains bits whose higher nibble is <= 7
// higher half contains bits whose higher nibble is >= 8
const fn build_bitmap() -> ([u8; 16], [u8; 16]) {
    let mut bitmap_0_7 = [0u8; 16]; // 0x00..0x7F
    let mut bitmap_8_15 = [0u8; 16]; // 0x80..0xFF
    let mut i = 0;
    while i < 256 {
        if bit_set(i as u8) {
            // Nibbles
            let (lo, hi) = (i & 0x0F, i >> 4);
            if i < 128 {
                bitmap_0_7[lo] |= 1 << hi;
            } else {
                bitmap_8_15[lo] |= 1 << hi;
            }
        }
        i += 1;
    }
    (bitmap_0_7, bitmap_8_15)
}

const BITMAPS: ([u8; 16], [u8; 16]) = build_bitmap();
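
// Layout: for a byte b, row BITMAPS.x[b & 0x0F] holds one membership bit
// per high nibble. A scalar equivalent of the lookup (illustrative sketch,
// not part of this commit):
//
//     fn in_set(b: u8) -> bool {
//         let (lo, hi) = ((b & 0x0F) as usize, (b >> 4) as usize);
//         let row = if b < 0x80 { BITMAPS.0[lo] } else { BITMAPS.1[lo] };
//         row & (1 << (hi % 8)) != 0
//     }
//
// (To match this scheme, the high-half branch of build_bitmap would need
// `1 << (hi - 8)`; as written that branch never executes because every
// accepted byte is ASCII, so bitmap_8_15 stays all-zero.)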

#[inline]
unsafe fn match_header_name_char_16_neon(ptr: *const u8) -> usize {
    let bitmaps = BITMAPS;
    // NOTE: ideally compile-time constants
    let (bitmap_0_7, _bitmap_8_15) = bitmaps;
    let bitmap_0_7 = vld1q_u8(bitmap_0_7.as_ptr());
    // let bitmap_8_15 = vld1q_u8(bitmap_8_15.as_ptr());

    // Initialize the bitmask_lookup.
    const BITMASK_LOOKUP_DATA: [u8; 16] =
        [1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128];
    let bitmask_lookup = vld1q_u8(BITMASK_LOOKUP_DATA.as_ptr());

    // Load 16 input bytes.
    let input = vld1q_u8(ptr);

    // Extract indices for row_0_7.
    let indices_0_7 = vandq_u8(input, vdupq_n_u8(0x8F)); // 0b1000_1111

    // Extract indices for row_8_15.
    // let msb = vandq_u8(input, vdupq_n_u8(0x80));
    // let indices_8_15 = veorq_u8(indices_0_7, msb);

    // Fetch row_0_7 and row_8_15.
    let row_0_7 = vqtbl1q_u8(bitmap_0_7, indices_0_7);
    // let row_8_15 = vqtbl1q_u8(bitmap_8_15, indices_8_15);

    // Calculate a bitmask, i.e. (1 << hi_nibble % 8).
    let bitmask = vqtbl1q_u8(bitmask_lookup, vshrq_n_u8(input, 4));

    // Choose row halves depending on higher nibbles.
    // let bitsets = vorrq_u8(row_0_7, row_8_15);
    let bitsets = row_0_7;

    // Finally check which bytes belong to the set.
    let tmp = vandq_u8(bitsets, bitmask);
    let result = vceqq_u8(tmp, bitmask);

    offsetz(result) as usize
}
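
// Why the high-half lookup can be dropped: the header-name set contains
// only ASCII bytes. For input >= 0x80, indices_0_7 keeps the top bit set,
// and TBL (vqtbl1q_u8) yields 0 for any out-of-range index, so those lanes
// read 0 from the table and fail the vceqq_u8 check, i.e. non-ASCII bytes
// are rejected without consulting a second bitmap.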

// Reference impl of neon 256-bitset filter
#[allow(dead_code)]
unsafe fn naive_bitmap256_match(ptr: *const u8) -> usize {
    let bitmaps = BITMAPS;
    // NOTE: ideally compile-time constants
    let (bitmap_0_7, bitmap_8_15) = bitmaps;
    let bitmap_0_7 = vld1q_u8(bitmap_0_7.as_ptr());
    let bitmap_8_15 = vld1q_u8(bitmap_8_15.as_ptr());

    // Initialize the bitmask_lookup.
    let bitmask_lookup_data = [1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128];
    let bitmask_lookup = vld1q_u8(bitmask_lookup_data.as_ptr());

    // Load 16 input bytes.
    let input = vld1q_u8(ptr);

    // Extract indices for row_0_7.
    let indices_0_7 = vandq_u8(input, vdupq_n_u8(0x8F)); // 0b1000_1111

    // Extract indices for row_8_15.
    let msb = vandq_u8(input, vdupq_n_u8(0x80));
    let indices_8_15 = veorq_u8(indices_0_7, msb);

    // Fetch row_0_7 and row_8_15.
    let row_0_7 = vqtbl1q_u8(bitmap_0_7, indices_0_7);
    let row_8_15 = vqtbl1q_u8(bitmap_8_15, indices_8_15);

    // Calculate a bitmask, i.e. (1 << hi_nibble % 8).
    let bitmask = vqtbl1q_u8(bitmask_lookup, vshrq_n_u8(input, 4));

    // Choose row halves depending on higher nibbles.
    let bitsets = vorrq_u8(row_0_7, row_8_15);

    // Finally check which bytes belong to the set.
    let tmp = vandq_u8(bitsets, bitmask);
    let result = vceqq_u8(tmp, bitmask);

    offsetz(result) as usize
}
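
// The reference version keeps both bitmap halves, which is only needed
// when the byte set contains values >= 0x80; the trimmed variant above
// exploits the all-ASCII set to skip one table load and one TBL per block.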

#[inline]
unsafe fn match_url_char_16_neon(ptr: *const u8) -> usize {
    let input = vld1q_u8(ptr);

    // Check that b'!' <= input <= b'~'
    let result = vandq_u8(
        vcleq_u8(vdupq_n_u8(b'!'), input),
        vcleq_u8(input, vdupq_n_u8(b'~')),
    );
    // Check that input != b'<' and input != b'>'
    let lt = vceqq_u8(input, vdupq_n_u8(b'<'));
    let gt = vceqq_u8(input, vdupq_n_u8(b'>'));
    let ltgt = vorrq_u8(lt, gt);
    // AND-NOT (vbicq: result & !ltgt) clears the lanes that matched '<' or '>'
    let result = vbicq_u8(result, ltgt);

    offsetz(result) as usize
}
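
// Scalar equivalent (illustrative): accept b iff b is printable ASCII
// (0x21..=0x7E) and b != b'<' && b != b'>'.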

#[inline]
unsafe fn match_header_value_char_16_neon(ptr: *const u8) -> usize {
    let input = vld1q_u8(ptr);

    // Accept b if (b' ' <= b && b != 127) || b == 9
    let result = vcleq_u8(vdupq_n_u8(b' '), input);

    // Allow tab
    let tab = vceqq_u8(input, vdupq_n_u8(0x09));
    let result = vorrq_u8(result, tab);

    // Disallow del
    let del = vceqq_u8(input, vdupq_n_u8(0x7F));
    let result = vbicq_u8(result, del);

    offsetz(result) as usize
}
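
// Note this deliberately accepts bytes >= 0x80 (obs-text), matching the
// crate's HEADER_VALUE_MAP table that the tests below check against.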

#[inline]
unsafe fn offsetz(x: uint8x16_t) -> u32 {
    // NOT the vector since it's faster to operate with zeros instead
    offsetnz(vmvnq_u8(x))
}

#[inline]
unsafe fn offsetnz(x: uint8x16_t) -> u32 {
    // Extract two u64
    let x = vreinterpretq_u64_u8(x);
    let low: u64 = std::mem::transmute(vget_low_u64(x));
    let high: u64 = std::mem::transmute(vget_high_u64(x));

    // Index of the first nonzero byte, scanning in memory (little-endian) order.
    #[inline]
    fn clz(x: u64) -> u32 {
        // perf: rust will unroll this loop
        // and it's much faster than rbit + clz so voila
        for (i, b) in x.to_ne_bytes().iter().copied().enumerate() {
            if b != 0 {
                return i as u32;
            }
        }
        8 // Technically not reachable since zero-guarded
    }

    // NOTE: need to revisit given offsetz perf is critical to overall perf
    // clz using "binary search" masking, more complex but possibly faster
    #[inline]
    #[allow(dead_code)]
    fn clzb(mut value: u64) -> u32 {
        let mut offset = 0;

        // Level 1: Check 32-bit halves
        const MASK_32: u64 = 0x00000000FFFFFFFF;
        if (value & MASK_32) == 0 {
            offset += 4;
            value >>= 32;
        }

        // Level 2: Check 16-bit quarters
        const MASK_16: u64 = 0x0000FFFF;
        if (value & MASK_16) == 0 {
            offset += 2;
            value >>= 16;
        }

        // Level 3: Check 8-bit octets
        const MASK_8: u64 = 0x00FF;
        if (value & MASK_8) == 0 {
            offset += 1;
        }

        offset
    }

    if low != 0 {
        return clz(low);
    } else if high != 0 {
        return 8 + clz(high);
    } else {
        return 16;
    }
}
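
// Worked example (illustrative): if the comparison result has 0xFF in
// lanes 0..=10 and 0x00 in lane 11 (first rejected byte at offset 11),
// vmvnq_u8 flips it so lane 11 holds the first nonzero byte; offsetnz
// then scans the low u64 bytewise and returns 11.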

// NOTE: Slower than offsetnz, somewhat "surprisingly", despite using fewer instructions,
// possibly lookup table loading is inefficient and isn't hoisted out of the loops
#[inline]
#[allow(dead_code)]
unsafe fn offsetz2(input: uint8x16_t) -> u32 {
    const LOOKUP_TABLE: [u8; 16] = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1];
    const LOOKUP: uint8x16_t = unsafe { std::mem::transmute(LOOKUP_TABLE) };

    let mask = input;
    let indexed: uint8x16_t = vbicq_u8(LOOKUP, mask);

    let offset: u8 = vmaxvq_u8(indexed);

    16 - (offset as u32)
}
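
// How offsetz2 works: LOOKUP[i] == 16 - i, and BIC zeroes every lane whose
// comparison result is 0xFF (matched), so vmaxvq_u8 returns 16 minus the
// index of the first unmatched lane; subtracting from 16 recovers that
// index, and an all-matched vector yields 16.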

#[test]
fn neon_code_matches_uri_chars_table() {
    unsafe {
        assert!(byte_is_allowed(b'_', match_uri_vectored));

        for (b, allowed) in crate::URI_MAP.iter().cloned().enumerate() {
            assert_eq!(
                byte_is_allowed(b as u8, match_uri_vectored),
                allowed,
                "byte_is_allowed({:?}) should be {:?}",
                b,
                allowed,
            );
        }
    }
}

#[test]
fn neon_code_matches_header_value_chars_table() {
    unsafe {
        assert!(byte_is_allowed(b'_', match_header_value_vectored));

        for (b, allowed) in crate::HEADER_VALUE_MAP.iter().cloned().enumerate() {
            assert_eq!(
                byte_is_allowed(b as u8, match_header_value_vectored),
                allowed,
                "byte_is_allowed({:?}) should be {:?}",
                b,
                allowed,
            );
        }
    }
}

#[test]
fn neon_code_matches_header_name_chars_table() {
    unsafe {
        assert!(byte_is_allowed(b'_', match_header_name_vectored));

        for (b, allowed) in crate::HEADER_NAME_MAP.iter().cloned().enumerate() {
            assert_eq!(
                byte_is_allowed(b as u8, match_header_name_vectored),
                allowed,
                "byte_is_allowed({:?}) should be {:?}",
                b,
                allowed,
            );
        }
    }
}

#[cfg(test)]
unsafe fn byte_is_allowed(byte: u8, f: unsafe fn(bytes: &mut Bytes<'_>)) -> bool {
    let mut slice = [b'_'; 16];
    slice[10] = byte;
    let mut bytes = Bytes::new(&slice);

    f(&mut bytes);

    match bytes.pos() {
        16 => true,
        10 => false,
        x => panic!("unexpected pos: {}", x),
    }
}
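
// Harness note: the probe byte sits at offset 10 of an otherwise all-'_'
// 16-byte block, so the matcher either consumes the whole block
// (pos() == 16, byte allowed) or stops exactly at the probe
// (pos() == 10, byte rejected); any other position is a bug.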
