[Compile] Add NEON implementation for bf16->fp32 cast #134297
Conversation
Let's trigger a dashboard run for this.
Sure, https://github.com/pytorch/pytorch/actions/runs/10529131469 [Edit] Realized I did this change before the split, so alas it's not really usable. Let's test in trunk.
```cpp
int32x4_t shift = vdupq_n_s32(16);
auto u16_low1 = vget_low_u16(u16_8);
auto u16_high1 = vget_high_u16(u16_8);
float32x4_t f32x4_0 = vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16_low1), shift));
```
Seems reasonable, but if the input is interleaved then for the upper half you can just do a vectorized (input & 0xFF00) and the reinterpret, saving the get_high and movl instructions for that half. For the lower half you would still need those.
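For context, here is a minimal standalone sketch of the shift-and-widen conversion the snippet above implements. The helper name, the fixed width of 8 elements, and the raw `uint16_t` input type are illustrative assumptions, not the PR's actual code:

```cpp
#include <arm_neon.h>
#include <cstdint>

// Sketch: convert 8 bfloat16 values (raw 16-bit patterns) to 8 floats.
// A bf16 value is the upper 16 bits of the corresponding float32, so widening
// each lane to 32 bits and shifting it into the high half recovers the float.
static inline void bf16x8_to_f32(const uint16_t* in, float* out) {
  uint16x8_t u16_8 = vld1q_u16(in);             // 8 raw bf16 bit patterns
  int32x4_t shift = vdupq_n_s32(16);
  uint32x4_t lo = vshlq_u32(vmovl_u16(vget_low_u16(u16_8)), shift);
  uint32x4_t hi = vshlq_u32(vmovl_u16(vget_high_u16(u16_8)), shift);
  vst1q_f32(out, vreinterpretq_f32_u32(lo));    // lower 4 floats
  vst1q_f32(out + 4, vreinterpretq_f32_u32(hi)); // upper 4 floats
}
```

With optimization enabled, compilers typically fuse the widen-and-shift pair into shll/shll2-style instructions, which is what the improved assembly in the PR description below shows.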
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot merge -f "This is weird: workflow dispatch jobs do not show up in the signal box, but still delay the merge"
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This changes the assembly generated for the following routine

```cpp
void bfloat16tofloat(c10::BFloat16* in, float* out) {
  auto tmp0 = at::vec::Vectorized<c10::BFloat16>::loadu(in, 8);
  auto tmp1 = at::vec::convert<float>(tmp0);
  tmp1.store(out);
}
```

from

```asm
bfloat16tofloat(c10::BFloat16*, float*):
0000000000000034 stp x29, x30, [sp, #-0x10]!
0000000000000038 mov x29, sp
000000000000003c sub x9, sp, #0x90
0000000000000040 and sp, x9, #0xffffffffffffffe0
0000000000000044 mov x8, #0x0
0000000000000048 adrp x9, 0 ; 0x0
000000000000004c ldr x9, [x9]
0000000000000050 ldr x9, [x9]
0000000000000054 str x9, [sp, #0x88]
0000000000000058 stp xzr, xzr, [sp, #0x10]
000000000000005c ldr q0, [x0]
0000000000000060 str q0, [sp]
0000000000000064 ldr q1, [sp, #0x10]
0000000000000068 stp q0, q1, [sp, #0x20]
000000000000006c add x9, sp, #0x40
0000000000000070 add x10, sp, #0x20
0000000000000074 add x11, x10, x8
0000000000000078 ldp d0, d1, [x11]
000000000000007c shll.4s v0, v0, #16
0000000000000080 shll.4s v1, v1, #16
0000000000000084 stp q0, q1, [x9], #0x20
0000000000000088 add x8, x8, #0x10
000000000000008c cmp x8, #0x20
0000000000000090 b.ne 0x74
0000000000000094 add x8, sp, #0x40
0000000000000098 ld1.4s { v0, v1 }, [x8]
000000000000009c st1.4s { v0, v1 }, [x1]
00000000000000a0 ldr x8, [sp, #0x88]
00000000000000a4 adrp x9, 0 ; 0x0
00000000000000a8 ldr x9, [x9]
00000000000000ac ldr x9, [x9]
00000000000000b0 cmp x9, x8
00000000000000b4 b.ne 0xc4
00000000000000b8 mov sp, x29
00000000000000bc ldp x29, x30, [sp], #0x10
00000000000000c0 ret
00000000000000c4 bl 0xc4
```

to

```asm
bfloat16tofloat(c10::BFloat16*, float*):
0000000000000034 ldr q0, [x0]
0000000000000038 shll.4s v1, v0, #16
000000000000003c shll2.4s v2, v0, #16
0000000000000040 st1.4s { v1, v2 }, [x1]
0000000000000044 ret
```

As a result, this speeds up `python3 torchchat.py generate stories110M --num-samples 3 --compile --device cpu --dtype bfloat16` from 33 to 90 tokens/sec.

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

Pull Request resolved: #134297
Approved by: https://github.com/kimishpatel
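As background, the cast being vectorized is just a bit-pattern move: bf16 stores the top 16 bits of an IEEE-754 float32. A minimal scalar sketch of the same conversion (function name is illustrative):

```cpp
#include <cstdint>
#include <cstring>

// Sketch: convert one raw bfloat16 bit pattern to float32.
// Shift the 16 stored bits into the upper half of a 32-bit word, then
// reinterpret that word as a float.
static inline float bf16_bits_to_f32(uint16_t bits) {
  uint32_t widened = static_cast<uint32_t>(bits) << 16;
  float result;
  std::memcpy(&result, &widened, sizeof(result));  // type-safe bit reinterpret
  return result;
}
```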