Skip to content

Commit 8fb43ec

Browse files
authored
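[SDPA] Remove the `static` keyword from the local constants in exp_u20 (#3291)
* [sdpa] Remove the `static` keyword from the local constants in exp_u20, declaring them `const` instead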
1 parent b0cd68b commit 8fb43ec

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -130,27 +130,24 @@ inline Vectorized<scalar_t> exp_u20(Vectorized<scalar_t> data) {
130130
inline Vectorized<float> exp_u20(Vectorized<float> data) {
131131
__m512 values = __m512(data);
132132
// A faster version of exp with ULP=20
133-
static __m512 vec_factorial_1 =
134-
_mm512_set1_ps(0.999999701f); // 1/factorial(1)
135-
static __m512 vec_factorial_2 =
136-
_mm512_set1_ps(0.499991506f); // 1/factorial(2)
137-
static __m512 vec_factorial_3 =
138-
_mm512_set1_ps(0.166676521f); // 1/factorial(3)
139-
static __m512 vec_factorial_4 =
133+
const __m512 vec_factorial_1 = _mm512_set1_ps(0.999999701f); // 1/factorial(1)
134+
const __m512 vec_factorial_2 = _mm512_set1_ps(0.499991506f); // 1/factorial(2)
135+
const __m512 vec_factorial_3 = _mm512_set1_ps(0.166676521f); // 1/factorial(3)
136+
const __m512 vec_factorial_4 =
140137
_mm512_set1_ps(0.0418978221f); // 1/factorial(4)
141-
static __m512 vec_factorial_5 =
138+
const __m512 vec_factorial_5 =
142139
_mm512_set1_ps(0.00828929059f); // 1/factorial(5)
143-
static __m512 vec_exp_log2ef =
140+
const __m512 vec_exp_log2ef =
144141
(__m512)_mm512_set1_epi32(0x3fb8aa3b); // log2(e)
145-
static __m512 vec_half = _mm512_set1_ps(0.5f);
146-
static __m512 vec_one = _mm512_set1_ps(1.f);
147-
static __m512 vec_zero = _mm512_set1_ps(0.f);
148-
static __m512 vec_two = _mm512_set1_ps(2.f);
149-
static __m512 vec_ln2f = (__m512)_mm512_set1_epi32(0x3f317218); // ln(2)
150-
static __m512 vec_ln_flt_min = (__m512)_mm512_set1_epi32(0xc2aeac50);
151-
static __m512 vec_ln_flt_max = (__m512)_mm512_set1_epi32(0x42b17218);
152-
static __m512i vec_127 = _mm512_set1_epi32(0x0000007f);
153-
static int n_mantissa_bits = 23;
142+
const __m512 vec_half = _mm512_set1_ps(0.5f);
143+
const __m512 vec_one = _mm512_set1_ps(1.f);
144+
const __m512 vec_zero = _mm512_set1_ps(0.f);
145+
const __m512 vec_two = _mm512_set1_ps(2.f);
146+
const __m512 vec_ln2f = (__m512)_mm512_set1_epi32(0x3f317218); // ln(2)
147+
const __m512 vec_ln_flt_min = (__m512)_mm512_set1_epi32(0xc2aeac50);
148+
const __m512 vec_ln_flt_max = (__m512)_mm512_set1_epi32(0x42b17218);
149+
const __m512i vec_127 = _mm512_set1_epi32(0x0000007f);
150+
const int n_mantissa_bits = 23;
154151

155152
// exp(x) =
156153
// = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem

0 commit comments

Comments
 (0)