Skip to content

[Compile] Add NEON implementation for bf16->fp32 cast #134297

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from

Conversation

malfet
Copy link
Contributor

@malfet malfet commented Aug 23, 2024

This changes assembly generated for the following routine

void bfloat16tofloat(c10::BFloat16* in, float* out) {
        auto tmp0 = at::vec::Vectorized<c10::BFloat16>::loadu(in, 8);
        auto tmp1 = at::vec::convert<float>(tmp0);
        tmp1.store(out);
}

from

bfloat16tofloat(c10::BFloat16*, float*):
0000000000000034        stp     x29, x30, [sp, #-0x10]!
0000000000000038        mov     x29, sp
000000000000003c        sub     x9, sp, #0x90
0000000000000040        and     sp, x9, #0xffffffffffffffe0
0000000000000044        mov     x8, #0x0
0000000000000048        adrp    x9, 0 ; 0x0
000000000000004c        ldr     x9, [x9]
0000000000000050        ldr     x9, [x9]
0000000000000054        str     x9, [sp, #0x88]
0000000000000058        stp     xzr, xzr, [sp, #0x10]
000000000000005c        ldr     q0, [x0]
0000000000000060        str     q0, [sp]
0000000000000064        ldr     q1, [sp, #0x10]
0000000000000068        stp     q0, q1, [sp, #0x20]
000000000000006c        add     x9, sp, #0x40
0000000000000070        add     x10, sp, #0x20
0000000000000074        add     x11, x10, x8
0000000000000078        ldp     d0, d1, [x11]
000000000000007c        shll.4s v0, v0, #16
0000000000000080        shll.4s v1, v1, #16
0000000000000084        stp     q0, q1, [x9], #0x20
0000000000000088        add     x8, x8, #0x10
000000000000008c        cmp     x8, #0x20
0000000000000090        b.ne    0x74
0000000000000094        add     x8, sp, #0x40
0000000000000098        ld1.4s  { v0, v1 }, [x8]
000000000000009c        st1.4s  { v0, v1 }, [x1]
00000000000000a0        ldr     x8, [sp, #0x88]
00000000000000a4        adrp    x9, 0 ; 0x0
00000000000000a8        ldr     x9, [x9]
00000000000000ac        ldr     x9, [x9]
00000000000000b0        cmp     x9, x8
00000000000000b4        b.ne    0xc4
00000000000000b8        mov     sp, x29
00000000000000bc        ldp     x29, x30, [sp], #0x10
00000000000000c0        ret
00000000000000c4        bl      0xc4

to

bfloat16tofloat(c10::BFloat16*, float*):
0000000000000034        ldr     q0, [x0]
0000000000000038        shll.4s v1, v0, #16
000000000000003c        shll2.4s        v2, v0, #16
0000000000000040        st1.4s  { v1, v2 }, [x1]
0000000000000044        ret

And as result speeds up python3 torchchat.py generate stories110M --num-samples 3 --compile --device cpu --dtype bfloat16 from 33 to 90 tokens/sec

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

Copy link

pytorch-bot bot commented Aug 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134297

Note: Links to docs will display an error until the docs builds have been completed.

❌ 9 New Failures

As of commit 6cb3771 with merge base d1abd62 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Aug 23, 2024
@malfet malfet requested review from jgong5 and desertfire August 23, 2024 00:40
@malfet malfet added topic: performance topic category release notes: inductor ciflow/linux-aarch64 linux aarch64 CI workflow ciflow/trunk Trigger trunk jobs on your pull request labels Aug 23, 2024
@desertfire
Copy link
Contributor

Let's trigger a dashboard run for this.

@malfet
Copy link
Contributor Author

malfet commented Aug 23, 2024

Let's trigger a dashboard run for this.

Sure, https://github.com/pytorch/pytorch/actions/runs/10529131469

[Edit] Realized I did this change before the split, so alas it's not really usable. Let's test in trunk

int32x4_t shift = vdupq_n_s32(16);
auto u16_low1 = vget_low_u16(u16_8);
auto u16_high1 = vget_high_u16(u16_8);
float32x4_t f32x4_0 = vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16_low1), shift));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems reasonable, but if input is interleaved then you can just do vectorized (input & 0xFF00) and the reinterpret, for upper half and save get_high, and movl instructions for upper half. For lower hafl you would still need those.

@malfet
Copy link
Contributor Author

malfet commented Aug 23, 2024

@pytorchbot merge

@malfet malfet changed the title [Compile] Add NEON implementation of bf16->fp32 cast [Compile] Add NEON implementation for bf16->fp32 cast Aug 23, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@malfet
Copy link
Contributor Author

malfet commented Aug 23, 2024

@pytorchbot merge -f "This is weird: workflow dispatch jobs do not show up in the signal box, but still delay the merge"

@pytorchmergebot
Copy link
Collaborator

The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorch-bot bot pushed a commit that referenced this pull request Sep 13, 2024
This changes assembly generated for the following routine
```cpp
void bfloat16tofloat(c10::BFloat16* in, float* out) {
        auto tmp0 = at::vec::Vectorized<c10::BFloat16>::loadu(in, 8);
        auto tmp1 = at::vec::convert<float>(tmp0);
        tmp1.store(out);
}
```
from
```asm
bfloat16tofloat(c10::BFloat16*, float*):
0000000000000034        stp     x29, x30, [sp, #-0x10]!
0000000000000038        mov     x29, sp
000000000000003c        sub     x9, sp, #0x90
0000000000000040        and     sp, x9, #0xffffffffffffffe0
0000000000000044        mov     x8, #0x0
0000000000000048        adrp    x9, 0 ; 0x0
000000000000004c        ldr     x9, [x9]
0000000000000050        ldr     x9, [x9]
0000000000000054        str     x9, [sp, #0x88]
0000000000000058        stp     xzr, xzr, [sp, #0x10]
000000000000005c        ldr     q0, [x0]
0000000000000060        str     q0, [sp]
0000000000000064        ldr     q1, [sp, #0x10]
0000000000000068        stp     q0, q1, [sp, #0x20]
000000000000006c        add     x9, sp, #0x40
0000000000000070        add     x10, sp, #0x20
0000000000000074        add     x11, x10, x8
0000000000000078        ldp     d0, d1, [x11]
000000000000007c        shll.4s v0, v0, #16
0000000000000080        shll.4s v1, v1, #16
0000000000000084        stp     q0, q1, [x9], #0x20
0000000000000088        add     x8, x8, #0x10
000000000000008c        cmp     x8, #0x20
0000000000000090        b.ne    0x74
0000000000000094        add     x8, sp, #0x40
0000000000000098        ld1.4s  { v0, v1 }, [x8]
000000000000009c        st1.4s  { v0, v1 }, [x1]
00000000000000a0        ldr     x8, [sp, #0x88]
00000000000000a4        adrp    x9, 0 ; 0x0
00000000000000a8        ldr     x9, [x9]
00000000000000ac        ldr     x9, [x9]
00000000000000b0        cmp     x9, x8
00000000000000b4        b.ne    0xc4
00000000000000b8        mov     sp, x29
00000000000000bc        ldp     x29, x30, [sp], #0x10
00000000000000c0        ret
00000000000000c4        bl      0xc4
```
to
```asm
bfloat16tofloat(c10::BFloat16*, float*):
0000000000000034        ldr     q0, [x0]
0000000000000038        shll.4s v1, v0, #16
000000000000003c        shll2.4s        v2, v0, #16
0000000000000040        st1.4s  { v1, v2 }, [x1]
0000000000000044        ret
```

And as result speeds up `python3 torchchat.py generate stories110M --num-samples 3 --compile --device cpu --dtype bfloat16` from 33 to 90 tokens/sec

Pull Request resolved: #134297
Approved by: https://github.com/kimishpatel
@github-actions github-actions bot deleted the malfet/aarch64-speedup-bfloat16-float-convert branch October 1, 2024 02:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/linux-aarch64 linux aarch64 CI workflow ciflow/trunk Trigger trunk jobs on your pull request Merged module: cpu CPU specific problem (e.g., perf, algorithm) release notes: inductor topic: performance topic category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants