Skip to content

Commit d0295a5

Browse files
eduardosm authored and Amanieu committed
Implement AVX512BW 16-bit shift by immediate (srai_epi16) with simd_shr instead of LLVM intrinsics
1 parent 1a7e1e3 commit d0295a5

File tree

1 file changed

+10
-25
lines changed

1 file changed

+10
-25
lines changed

crates/core_arch/src/x86/avx512bw.rs

Lines changed: 10 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -5996,9 +5996,7 @@ pub unsafe fn _mm_maskz_sra_epi16(k: __mmask8, a: __m128i, count: __m128i) -> __
59965996
#[rustc_legacy_const_generics(1)]
59975997
pub unsafe fn _mm512_srai_epi16<const IMM8: u32>(a: __m512i) -> __m512i {
59985998
static_assert_uimm_bits!(IMM8, 8);
5999-
let a = a.as_i16x32();
6000-
let r = vpsraiw(a, IMM8);
6001-
transmute(r)
5999+
transmute(simd_shr(a.as_i16x32(), i16x32::splat(IMM8.min(15) as i16)))
60026000
}
60036001

60046002
/// Shift packed 16-bit integers in a right by imm8 while shifting in sign bits, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -6014,8 +6012,7 @@ pub unsafe fn _mm512_mask_srai_epi16<const IMM8: u32>(
60146012
a: __m512i,
60156013
) -> __m512i {
60166014
static_assert_uimm_bits!(IMM8, 8);
6017-
let a = a.as_i16x32();
6018-
let shf = vpsraiw(a, IMM8);
6015+
let shf = simd_shr(a.as_i16x32(), i16x32::splat(IMM8.min(15) as i16));
60196016
transmute(simd_select_bitmask(k, shf, src.as_i16x32()))
60206017
}
60216018

@@ -6028,9 +6025,8 @@ pub unsafe fn _mm512_mask_srai_epi16<const IMM8: u32>(
60286025
#[rustc_legacy_const_generics(2)]
60296026
pub unsafe fn _mm512_maskz_srai_epi16<const IMM8: u32>(k: __mmask32, a: __m512i) -> __m512i {
60306027
static_assert_uimm_bits!(IMM8, 8);
6031-
let a = a.as_i16x32();
6032-
let shf = vpsraiw(a, IMM8);
6033-
let zero = _mm512_setzero_si512().as_i16x32();
6028+
let shf = simd_shr(a.as_i16x32(), i16x32::splat(IMM8.min(15) as i16));
6029+
let zero = i16x32::splat(0);
60346030
transmute(simd_select_bitmask(k, shf, zero))
60356031
}
60366032

@@ -6047,8 +6043,7 @@ pub unsafe fn _mm256_mask_srai_epi16<const IMM8: u32>(
60476043
a: __m256i,
60486044
) -> __m256i {
60496045
static_assert_uimm_bits!(IMM8, 8);
6050-
let imm8 = IMM8 as i32;
6051-
let r = psraiw256(a.as_i16x16(), imm8);
6046+
let r = simd_shr(a.as_i16x16(), i16x16::splat(IMM8.min(15) as i16));
60526047
transmute(simd_select_bitmask(k, r, src.as_i16x16()))
60536048
}
60546049

@@ -6061,9 +6056,8 @@ pub unsafe fn _mm256_mask_srai_epi16<const IMM8: u32>(
60616056
#[rustc_legacy_const_generics(2)]
60626057
pub unsafe fn _mm256_maskz_srai_epi16<const IMM8: u32>(k: __mmask16, a: __m256i) -> __m256i {
60636058
static_assert_uimm_bits!(IMM8, 8);
6064-
let imm8 = IMM8 as i32;
6065-
let r = psraiw256(a.as_i16x16(), imm8);
6066-
let zero = _mm256_setzero_si256().as_i16x16();
6059+
let r = simd_shr(a.as_i16x16(), i16x16::splat(IMM8.min(15) as i16));
6060+
let zero = i16x16::splat(0);
60676061
transmute(simd_select_bitmask(k, r, zero))
60686062
}
60696063

@@ -6080,8 +6074,7 @@ pub unsafe fn _mm_mask_srai_epi16<const IMM8: u32>(
60806074
a: __m128i,
60816075
) -> __m128i {
60826076
static_assert_uimm_bits!(IMM8, 8);
6083-
let imm8 = IMM8 as i32;
6084-
let r = psraiw128(a.as_i16x8(), imm8);
6077+
let r = simd_shr(a.as_i16x8(), i16x8::splat(IMM8.min(15) as i16));
60856078
transmute(simd_select_bitmask(k, r, src.as_i16x8()))
60866079
}
60876080

@@ -6094,9 +6087,8 @@ pub unsafe fn _mm_mask_srai_epi16<const IMM8: u32>(
60946087
#[rustc_legacy_const_generics(2)]
60956088
pub unsafe fn _mm_maskz_srai_epi16<const IMM8: u32>(k: __mmask8, a: __m128i) -> __m128i {
60966089
static_assert_uimm_bits!(IMM8, 8);
6097-
let imm8 = IMM8 as i32;
6098-
let r = psraiw128(a.as_i16x8(), imm8);
6099-
let zero = _mm_setzero_si128().as_i16x8();
6090+
let r = simd_shr(a.as_i16x8(), i16x8::splat(IMM8.min(15) as i16));
6091+
let zero = i16x8::splat(0);
61006092
transmute(simd_select_bitmask(k, r, zero))
61016093
}
61026094

@@ -10013,13 +10005,6 @@ extern "C" {
1001310005

1001410006
#[link_name = "llvm.x86.avx512.psra.w.512"]
1001510007
fn vpsraw(a: i16x32, count: i16x8) -> i16x32;
10016-
#[link_name = "llvm.x86.avx512.psrai.w.512"]
10017-
fn vpsraiw(a: i16x32, imm8: u32) -> i16x32;
10018-
10019-
#[link_name = "llvm.x86.avx2.psrai.w"]
10020-
fn psraiw256(a: i16x16, imm8: i32) -> i16x16;
10021-
#[link_name = "llvm.x86.sse2.psrai.w"]
10022-
fn psraiw128(a: i16x8, imm8: i32) -> i16x8;
1002310008

1002410009
#[link_name = "llvm.x86.avx512.psrav.w.512"]
1002510010
fn vpsravw(a: i16x32, count: i16x32) -> i16x32;

0 commit comments

Comments
 (0)