@@ -5996,9 +5996,7 @@ pub unsafe fn _mm_maskz_sra_epi16(k: __mmask8, a: __m128i, count: __m128i) -> __
5996
5996
#[rustc_legacy_const_generics(1)]
5997
5997
pub unsafe fn _mm512_srai_epi16<const IMM8: u32>(a: __m512i) -> __m512i {
5998
5998
static_assert_uimm_bits!(IMM8, 8);
5999
- let a = a.as_i16x32();
6000
- let r = vpsraiw(a, IMM8);
6001
- transmute(r)
5999
+ transmute(simd_shr(a.as_i16x32(), i16x32::splat(IMM8.min(15) as i16)))
6002
6000
}
6003
6001
6004
6002
/// Shift packed 16-bit integers in a right by imm8 while shifting in sign bits, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -6014,8 +6012,7 @@ pub unsafe fn _mm512_mask_srai_epi16<const IMM8: u32>(
6014
6012
a: __m512i,
6015
6013
) -> __m512i {
6016
6014
static_assert_uimm_bits!(IMM8, 8);
6017
- let a = a.as_i16x32();
6018
- let shf = vpsraiw(a, IMM8);
6015
+ let shf = simd_shr(a.as_i16x32(), i16x32::splat(IMM8.min(15) as i16));
6019
6016
transmute(simd_select_bitmask(k, shf, src.as_i16x32()))
6020
6017
}
6021
6018
@@ -6028,9 +6025,8 @@ pub unsafe fn _mm512_mask_srai_epi16<const IMM8: u32>(
6028
6025
#[rustc_legacy_const_generics(2)]
6029
6026
pub unsafe fn _mm512_maskz_srai_epi16<const IMM8: u32>(k: __mmask32, a: __m512i) -> __m512i {
6030
6027
static_assert_uimm_bits!(IMM8, 8);
6031
- let a = a.as_i16x32();
6032
- let shf = vpsraiw(a, IMM8);
6033
- let zero = _mm512_setzero_si512().as_i16x32();
6028
+ let shf = simd_shr(a.as_i16x32(), i16x32::splat(IMM8.min(15) as i16));
6029
+ let zero = i16x32::splat(0);
6034
6030
transmute(simd_select_bitmask(k, shf, zero))
6035
6031
}
6036
6032
@@ -6047,8 +6043,7 @@ pub unsafe fn _mm256_mask_srai_epi16<const IMM8: u32>(
6047
6043
a: __m256i,
6048
6044
) -> __m256i {
6049
6045
static_assert_uimm_bits!(IMM8, 8);
6050
- let imm8 = IMM8 as i32;
6051
- let r = psraiw256(a.as_i16x16(), imm8);
6046
+ let r = simd_shr(a.as_i16x16(), i16x16::splat(IMM8.min(15) as i16));
6052
6047
transmute(simd_select_bitmask(k, r, src.as_i16x16()))
6053
6048
}
6054
6049
@@ -6061,9 +6056,8 @@ pub unsafe fn _mm256_mask_srai_epi16<const IMM8: u32>(
6061
6056
#[rustc_legacy_const_generics(2)]
6062
6057
pub unsafe fn _mm256_maskz_srai_epi16<const IMM8: u32>(k: __mmask16, a: __m256i) -> __m256i {
6063
6058
static_assert_uimm_bits!(IMM8, 8);
6064
- let imm8 = IMM8 as i32;
6065
- let r = psraiw256(a.as_i16x16(), imm8);
6066
- let zero = _mm256_setzero_si256().as_i16x16();
6059
+ let r = simd_shr(a.as_i16x16(), i16x16::splat(IMM8.min(15) as i16));
6060
+ let zero = i16x16::splat(0);
6067
6061
transmute(simd_select_bitmask(k, r, zero))
6068
6062
}
6069
6063
@@ -6080,8 +6074,7 @@ pub unsafe fn _mm_mask_srai_epi16<const IMM8: u32>(
6080
6074
a: __m128i,
6081
6075
) -> __m128i {
6082
6076
static_assert_uimm_bits!(IMM8, 8);
6083
- let imm8 = IMM8 as i32;
6084
- let r = psraiw128(a.as_i16x8(), imm8);
6077
+ let r = simd_shr(a.as_i16x8(), i16x8::splat(IMM8.min(15) as i16));
6085
6078
transmute(simd_select_bitmask(k, r, src.as_i16x8()))
6086
6079
}
6087
6080
@@ -6094,9 +6087,8 @@ pub unsafe fn _mm_mask_srai_epi16<const IMM8: u32>(
6094
6087
#[rustc_legacy_const_generics(2)]
6095
6088
pub unsafe fn _mm_maskz_srai_epi16<const IMM8: u32>(k: __mmask8, a: __m128i) -> __m128i {
6096
6089
static_assert_uimm_bits!(IMM8, 8);
6097
- let imm8 = IMM8 as i32;
6098
- let r = psraiw128(a.as_i16x8(), imm8);
6099
- let zero = _mm_setzero_si128().as_i16x8();
6090
+ let r = simd_shr(a.as_i16x8(), i16x8::splat(IMM8.min(15) as i16));
6091
+ let zero = i16x8::splat(0);
6100
6092
transmute(simd_select_bitmask(k, r, zero))
6101
6093
}
6102
6094
@@ -10013,13 +10005,6 @@ extern "C" {
10013
10005
10014
10006
#[link_name = "llvm.x86.avx512.psra.w.512"]
10015
10007
fn vpsraw(a: i16x32, count: i16x8) -> i16x32;
10016
- #[link_name = "llvm.x86.avx512.psrai.w.512"]
10017
- fn vpsraiw(a: i16x32, imm8: u32) -> i16x32;
10018
-
10019
- #[link_name = "llvm.x86.avx2.psrai.w"]
10020
- fn psraiw256(a: i16x16, imm8: i32) -> i16x16;
10021
- #[link_name = "llvm.x86.sse2.psrai.w"]
10022
- fn psraiw128(a: i16x8, imm8: i32) -> i16x8;
10023
10008
10024
10009
#[link_name = "llvm.x86.avx512.psrav.w.512"]
10025
10010
fn vpsravw(a: i16x32, count: i16x32) -> i16x32;
0 commit comments