MX single node performance tracker #1768

Open · vkuzo opened this issue Feb 24, 2025 · 2 comments

This issue tracks single node performance of MX training and inference: fast gemms and fast fused kernels. Once this issue is complete, we will be able to train on a single node (8 GPUs) at SOTA performance with MXFP8; the inference story with MXFP8 and MXFP4 is TBD.

training performance summary

As of 2025-03-27

  • e2e pretraining speedup vs bf16 + compile on LLaMa 3 8B, 8 B200 GPUs, torchtitan with default settings
    • 🟢 float8 tensorwise: 1.19x
    • 🔴 mxfp8: 1.09x (should be similar to tensorwise's 1.19x once we fix all the issues). Right now scaling/casting to mx is slow
  • 🟢 gemm speedup: cuBLAS mxfp8 gemm is 2x to 3x faster than bf16 - done for now
  • 🟢 mx casting to dim0 with torch.compile achieves up to 67% of peak mem bw - done for now
  • 🔴 mx casting to dim1 is our main performance gap
  • 🔲 mx casting to dim0 + dim1 at the same time is postponed for now until we make the individual dim0 and dim1 kernels better

individual components

system overview (for training)

# There are three gemms in a forward + backward of a Linear layer:
#
# 1.       input @ weight_t    = output     (forward pass)
# 2. grad_output @ weight      = grad_input (backward pass)
# 3.     input_t @ grad_output = grad_weight (backward pass)
# 
# in Python pseudocode, we want the following (for mxfp8):

# forward pass

# inputs are in high precision
x_hp, w_hp = ...

# input @ weight_t = output
x_mx_dim0, x_scale_dim0 = to_mx(x_hp, dim=0)
w_mx_dim0, w_scale_dim0 = to_mx(w_hp, dim=0)
y = mx_gemm(x_mx_dim0, w_mx_dim0.t(), x_scale_dim0, w_scale_dim0)

# backward pass

# inputs are in high precision
x_hp, w_hp, go_hp = ...

# grad_output @ weight = grad_input
go_mx_dim0, go_scale_dim0 = to_mx(go_hp, dim=0)
w_mx_dim1, w_scale_dim1 = to_mx(w_hp.t().contiguous(), dim=0)
gi = mx_gemm(go_mx_dim0, w_mx_dim1.t(), go_scale_dim0, w_scale_dim1)

# input_t @ grad_output = grad_weight
go_mx_dim1, go_scale_dim1 = to_mx(go_hp.t().contiguous().t(), dim=0)
x_mx_dim1, x_scale_dim1 = to_mx(x_hp.t().contiguous(), dim=0)
gw = mx_gemm(go_mx_dim1, x_mx_dim1.t(), go_scale_dim1, x_scale_dim1)

We want:

  1. the mx gemm to be fast
  2. the cast from high precision to mx (to_mx in the pseudocode above) to be fast (a reference sketch of this cast follows this list)
  3. the cast from high precision to mx to be fused into preceding/subsequent ops where possible
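
For reference, here is a minimal sketch of what to_mx computes for mxfp8: each block of 32 elements along the scaled dim shares one power-of-two (E8M0-style) scale, and the elements are cast to fp8e4m3. This is an illustrative simplification (2D input, last dim divisible by 32, scale kept in float32 rather than packed e8m0), not the torchao implementation:

import torch

def to_mx_sketch(x_hp: torch.Tensor, block_size: int = 32):
    # scale along the last dim: each contiguous group of `block_size` elements
    # shares one power-of-two scale
    rows, cols = x_hp.shape
    x = x_hp.reshape(rows, cols // block_size, block_size)

    # per-block max magnitude
    amax = x.abs().amax(dim=-1, keepdim=True).float()

    # pick a power-of-two scale so the block max lands near the top of the
    # fp8e4m3 range (emax = 8); the clamp avoids log2(0) for all-zero blocks
    scale_exp = torch.floor(torch.log2(amax.clamp(min=2.0 ** -126))) - 8
    scale = torch.exp2(scale_exp)

    # clamp to the fp8e4m3 max (+-448) and cast; with this power-of-two scale
    # choice the largest block element can land slightly above 448
    x_mx = (x / scale).clamp(min=-448.0, max=448.0).to(torch.float8_e4m3fn)

    return x_mx.reshape(rows, cols), scale.reshape(rows, cols // block_size)

Casting along dim1 is the same computation on a transposed view (hence the .t().contiguous() in the pseudocode above), which is part of why the dim1 kernel is the harder one to make fast.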

gemm kernel

Expected peak throughput on NVIDIA B200, without sparsity: 2.25 petaFLOPS for bf16, 4.5 petaFLOPS for fp8/fp6 (2x from bf16), 9.0 petaFLOPS for fp4 (4x from bf16) (source: https://resources.nvidia.com/en-us-blackwell-architecture, pages 19-20)

| kernel | wrapper | current TFLOPs | peak throughput | notes |
| --- | --- | --- | --- | --- |
| mxfp8 cuBLAS | torch._scaled_mm | TBD | 4.5 petaFLOPS | landed, pytorch/pytorch#147548 |
| mxfp8 CUTLASS | torchao.ops.mx_fp8_bf16 | TBD | 4.5 petaFLOPS | landed, #1637 |
| mxfp4 CUTLASS | torchao.ops.mx_fp4_bf16 | TBD | 9.0 petaFLOPS | landed, #1661 |
| nvfp4 cuBLAS | torch._scaled_mm | TBD | 9.0 petaFLOPS | in progress, pytorch/pytorch#148792 |

Once we have machines where benchmarking is possible, we should add easily reproducible gemm benchmarks and fill out the TFLOP column in the table above.
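
As a starting point for those benchmarks, here is a minimal sketch of how the TFLOPs column could be measured; the gemm callable, shapes, and bf16 matmul below are placeholders for illustration, with the mx gemm wrappers from the table swapped in for the real measurement:

import torch

def benchmark_gemm_tflops(gemm_fn, M: int, K: int, N: int,
                          n_warmup: int = 10, n_iter: int = 100) -> float:
    # time `gemm_fn()` (an (M, K) @ (K, N) matmul) with CUDA events and
    # return achieved TFLOP/s
    for _ in range(n_warmup):
        gemm_fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iter):
        gemm_fn()
    end.record()
    torch.cuda.synchronize()
    sec_per_iter = start.elapsed_time(end) / 1e3 / n_iter
    return 2 * M * K * N / sec_per_iter / 1e12

M = K = N = 16384
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
print("bf16:", benchmark_gemm_tflops(lambda: a @ b, M, K, N), "TFLOP/s")
# the mxfp8/mxfp4 rows would instead time torch._scaled_mm /
# torchao.ops.mx_fp8_bf16 / torchao.ops.mx_fp4_bf16 calls on pre-cast operands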

scaling/casting kernels

Our current plan is to use torch.compile, same as we are doing with float8.
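
As a minimal sketch of that plan (reusing the illustrative to_mx_sketch from above), the cast is written in plain PyTorch ops and compiled, so that inductor can fuse the abs/amax/scale/cast chain into a single, ideally memory-bandwidth-bound, kernel:

import torch

# compile the high precision -> mx cast so the whole chain runs as one fused
# kernel, the same approach used for the float8 casts
to_mx_compiled = torch.compile(to_mx_sketch)

x = torch.randn(16384, 16384, device="cuda", dtype=torch.bfloat16)
x_mx_dim0, x_scale_dim0 = to_mx_compiled(x)
# the dim1 variant would cast x.t().contiguous() instead, as in the pseudocode above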

e2e training performance

From https://resources.nvidia.com/en-us-blackwell-architecture pages 19-20, on B200 the single GPU memory bandwidth we expect is 8 TB/s, the fp8/fp6 tensor core peak FLOPS is 4.5 petaFLOPS (without sparsity), and the fp4 tensor core peak FLOPS is 9.0 petaFLOPS (without sparsity).

  • we need a roofline of mx scaling/casting to identify the shapes that are expected to see speedups, and we should have a benchmark to compare observed vs theoretical (a rough sketch follows this list)
  • [blocked] eventually we should get to SOTA performance in torchtitan. Currently, this work is blocked by general issues with Blackwell support in PyTorch, such as NCCL not working. Tracking is here: [CUDA][Blackwell] Blackwell Tracking Issue pytorch#145949
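
Here is the rough scaling/casting roofline referenced above; it is a sketch that only counts the bf16 read plus the fp8 and scale writes against the 8 TB/s peak bandwidth figure, ignoring everything else:

def mx_cast_roofline_us(M: int, K: int, block_size: int = 32,
                        peak_mem_bw_bytes_per_s: float = 8e12) -> float:
    # lower bound for a bf16 -> mxfp8 cast of an (M, K) tensor, in microseconds
    bytes_read = M * K * 2                         # bf16 input
    bytes_written = M * K + (M * K // block_size)  # fp8 data + e8m0 scales
    return (bytes_read + bytes_written) / peak_mem_bw_bytes_per_s * 1e6

# a 16384 x 16384 cast moves ~0.8 GB, so it takes >= ~100 us at 8 TB/s; observed
# kernel time vs this number gives the fraction of peak bandwidth actually achieved
print(mx_cast_roofline_us(16384, 16384))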

e2e inference performance

  • need an inference roofline (a starting-point sketch follows this list)
  • need to decide where to benchmark
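
As one starting point for that roofline (only the compute-vs-memory-bound boundary; a real inference roofline would also need model shapes and batch sizes), the arithmetic-intensity crossover implied by the B200 numbers quoted above:

# dense (no sparsity) B200 numbers quoted earlier in this issue
PEAK_MEM_BW_BYTES_PER_S = 8e12      # 8 TB/s
PEAK_FLOPS = {
    "bf16": 2.25e15,
    "fp8": 4.5e15,
    "fp4": 9.0e15,
}

for fmt, flops in PEAK_FLOPS.items():
    # kernels below this arithmetic intensity (FLOPs per byte moved) are
    # memory-bandwidth bound (e.g. small-batch decode gemms); above it they
    # are compute bound (e.g. prefill gemms)
    crossover = flops / PEAK_MEM_BW_BYTES_PER_S
    print(f"{fmt}: compute bound above ~{crossover:.0f} FLOPs/byte")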

syed-ahmed commented:

For the to_mxfp8_dim1_kernel, the main performance blocker is shared memory bank conflicts, which arise from the transpose followed by the store:

# transpose the tile; the transpose is what forces the round trip through
# shared memory described below
col_normalized = tl.trans(col_normalized_t)
# cast the transposed tile to fp8e4m3
col_normalized = col_normalized.to(tl.float8e4nv)
# store the fp8 data to global memory in column-major layout
tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask)

The quantized mxfp8 data is first stored to shared memory. It is then loaded back from shared memory (using SASS LDS.U8), which is vectorized, and stored to global memory. Many bank conflicts occur during the loads from shared memory, as seen in the NCU profile.

[NCU profile screenshot showing the shared memory bank conflicts]

The above was generated using:

export USE_IR_LOC=ttgir
ncu --set full --clock-control none --import-source on -f -o report python benchmarks/mx_formats/cast_bench.py --mode dim1_mx_triton --M 16384 --K 16384

drisspg commented May 22, 2025

FP8 vs MXFP8 Benchmark Comparison

References

  • MXFP8: https://fburl.com/s2g726a1
    CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.print_after_conversion --training.compile --training.steps 150 --model.converters mx --mx.recipe_name "mxfp8" --profiling.enable_profiling
    step: 70 loss: 6.9682 memory: 35.94GiB(20.15%) tps: 12,798 tflops: 741.17 mfu: 16.47%
  • FP8 per-tensor: https://fburl.com/99cbavtm
    CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.print_after_conversion --training.compile --training.steps 150 --model.converters float8 --profiling.enable_profiling
    step: 70 loss: 7.0653 memory: 35.92GiB(20.14%) tps: 13,674 tflops: 791.90 mfu: 17.60%
  • BF16: https://fburl.com/29px2raz
    CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.print_after_conversion --training.compile --training.steps 150 --profiling.enable_profiling
    step: 70 loss: 6.9436 memory: 36.00GiB(20.18%) tps: 11,388 tflops: 659.54 mfu: 14.66%

QKV Operations (including Output Projection)

| Operation | MXFP8 (μs) | FP8 (μs) | Performance |
| --- | --- | --- | --- |
| First FP8 GEMM | 152 | 144 | FP8 is 5% faster |
| Second FP8 GEMM | 52 | 61 | MXFP8 is 15% faster |
| Third FP8 GEMM | 51 | 53 | noise (negligible difference) |
| Output Projection | 123 | 107 | FP8 is ~13% faster |

MLP Operations

| Test Run | MXFP8 (μs) | FP8 (μs) |
| --- | --- | --- |
| mat 1 | 403 | 377 |
| mat 2 | 383 | 376 |
| mat 3 | 400 | 364 |

Average Performance:

  • FP8: ~395 μs
  • MXFP8: ~372 μs
  • MXFP8 is ~6% faster on average

TL;DR: some of the forward-pass speed difference is attributable to gemm performance differences.

Macro Level

Fwd for MXFP8: 2.2 ms
Fwd for FP8 per-tensor: 1.33 ms

FWD: FP8 is 0.6x the latency of MXFP8

Bwd for MXFP8: ~12.12 ms
Bwd for FP8 per-tensor: 9.8 ms

BWD: FP8 is 0.8x the latency of MXFP8
