Skip to content

Commit d725ca4

Browse files
malfet authored and zhangxiaoli73 committed
[Compile] Add NEON implementation for bf16->fp32 cast (pytorch#134297)
This changes assembly generated for the following routine ```cpp void bfloat16tofloat(c10::BFloat16* in, float* out) { auto tmp0 = at::vec::Vectorized<c10::BFloat16>::loadu(in, 8); auto tmp1 = at::vec::convert<float>(tmp0); tmp1.store(out); } ``` from ```asm bfloat16tofloat(c10::BFloat16*, float*): 0000000000000034 stp x29, x30, [sp, #-0x10]! 0000000000000038 mov x29, sp 000000000000003c sub x9, sp, #0x90 0000000000000040 and sp, x9, #0xffffffffffffffe0 0000000000000044 mov x8, #0x0 0000000000000048 adrp x9, 0 ; 0x0 000000000000004c ldr x9, [x9] 0000000000000050 ldr x9, [x9] 0000000000000054 str x9, [sp, #0x88] 0000000000000058 stp xzr, xzr, [sp, #0x10] 000000000000005c ldr q0, [x0] 0000000000000060 str q0, [sp] 0000000000000064 ldr q1, [sp, #0x10] 0000000000000068 stp q0, q1, [sp, #0x20] 000000000000006c add x9, sp, #0x40 0000000000000070 add x10, sp, #0x20 0000000000000074 add x11, x10, x8 0000000000000078 ldp d0, d1, [x11] 000000000000007c shll.4s v0, v0, #16 0000000000000080 shll.4s v1, v1, #16 0000000000000084 stp q0, q1, [x9], #0x20 0000000000000088 add x8, x8, #0x10 000000000000008c cmp x8, #0x20 0000000000000090 b.ne 0x74 0000000000000094 add x8, sp, #0x40 0000000000000098 ld1.4s { v0, v1 }, [x8] 000000000000009c st1.4s { v0, v1 }, [x1] 00000000000000a0 ldr x8, [sp, #0x88] 00000000000000a4 adrp x9, 0 ; 0x0 00000000000000a8 ldr x9, [x9] 00000000000000ac ldr x9, [x9] 00000000000000b0 cmp x9, x8 00000000000000b4 b.ne 0xc4 00000000000000b8 mov sp, x29 00000000000000bc ldp x29, x30, [sp], #0x10 00000000000000c0 ret 00000000000000c4 bl 0xc4 ``` to ```asm bfloat16tofloat(c10::BFloat16*, float*): 0000000000000034 ldr q0, [x0] 0000000000000038 shll.4s v1, v0, #16 000000000000003c shll2.4s v2, v0, #16 0000000000000040 st1.4s { v1, v2 }, [x1] 0000000000000044 ret ``` And as result speeds up `python3 torchchat.py generate stories110M --num-samples 3 --compile --device cpu --dtype bfloat16` from 33 to 90 tokens/sec Pull Request resolved: pytorch#134297 Approved by: 
https://github.com/kimishpatel
1 parent 5b35d34 commit d725ca4

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

aten/src/ATen/cpu/vec/vec256/vec256_convert.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,24 @@ struct VecConvert<
223223
};
224224
#endif
225225

226+
#if defined(CPU_CAPABILITY_NEON)
// Fast bf16 -> fp32 cast: a bfloat16 is exactly the upper 16 bits of an
// IEEE-754 float32, so widening each 16-bit lane and shifting it left by
// 16 bits reproduces the float32 bit pattern.
template <>
struct VecConvert<float, 1, BFloat16, 1> {
  // Converts 8 packed BFloat16 lanes into 8 float32 lanes
  // (returned as one VectorizedN<float, 1> holding two float32x4_t).
  static inline VectorizedN<float, 1> apply(
      const VectorizedN<BFloat16, 1>& src) {
    VectorizedN<float, 1> result;
    // Reinterpret the 8 bf16 values as raw 16-bit integers.
    uint16x8_t u16_8 = vld1q_u16(reinterpret_cast<const uint16_t*>(&src[0]));
    // vshll_n_u16 widens u16 -> u32 and shifts left by an immediate 16 in a
    // single intrinsic; it maps directly to the SHLL/SHLL2 instructions,
    // avoiding the separate vmovl_u16 + vdupq_n_s32/vshlq_u32
    // (variable-shift) sequence and the extra shift-vector register.
    float32x4_t f32x4_0 =
        vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(u16_8), 16));
    float32x4_t f32x4_1 =
        vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(u16_8), 16));
    result[0] = {f32x4_0, f32x4_1};
    return result;
  }
};
#endif
243+
226244
template <typename src_t>
227245
struct VecConvert<
228246
float,

0 commit comments

Comments
 (0)