MX single node performance tracker #1768
The quantized mxfp8 data is first stored to shared memory, then loaded back from shared memory (with vectorized SASS `LDS.U8` instructions) and stored to global memory. While loading from shared memory, many bank conflicts occur, as seen in the NCU profile.
FP8 vs MXFP8 benchmark comparison, covering:
- QKV operations (including output projection)
- MLP operations
- average performance

TL;DR: part of the forward speed difference is attributable to gemm perf differences.
Macro level:
- fwd for MXFP8: 2.2 ms; fwd for FP8 is 0.6x the latency
- bwd for MXFP8: about 12.120 ms; bwd for FP8 is 0.8x the latency
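The latencies above come from timing forward and backward passes; a minimal harness for collecting such numbers might look like the sketch below (illustrative only: the plain bf16 MLP is a stand-in for the real block, and the FP8 / MXFP8 variants would be produced by converting the model with torchao before timing).

```python
import torch
from torch import nn
from torch.utils.benchmark import Timer

# Stand-in bf16 MLP; swap in the real block and the torchao FP8 / MXFP8 conversion here.
model = nn.Sequential(
    nn.Linear(4096, 14336), nn.SiLU(), nn.Linear(14336, 4096)
).cuda().to(torch.bfloat16)
x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)

def fwd():
    return model(x)

def fwd_bwd():
    model(x).sum().backward()

for name, fn in [("fwd", fwd), ("fwd + bwd", fwd_bwd)]:
    measurement = Timer(stmt="fn()", globals={"fn": fn}).blocked_autorange(min_run_time=1.0)
    print(f"{name}: {measurement.median * 1e3:.3f} ms")
```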
This issue tracks single node performance of MX training and inference: fast gemm, fast fused kernels. If this issue is complete, we can train on a single node (8 GPUs) at SOTA performance with MXFP8, and do inference (details TBD) with MXFP8 and MXFP4.
training performance summary
As of 2025-03-27
individual components
system overview (for training)
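Roughly, the per-linear-layer training dataflow looks like the sketch below (the `to_mx` / `from_mx` helpers are simplified toy stand-ins, not torchao's implementation; a real mx gemm consumes the fp8 data and e8m0 scales directly rather than dequantizing).

```python
import torch

def to_mx(x: torch.Tensor, block_size: int = 32):
    """Toy blockwise cast: fp8 data + one power-of-two scale per block of 32 (simplified)."""
    blocks = x.reshape(-1, block_size)
    amax = blocks.abs().amax(dim=1, keepdim=True).clamp(min=2**-127)
    scale = torch.exp2(torch.floor(torch.log2(amax)))   # e8m0-representable scale
    data = (blocks / scale).to(torch.float8_e4m3fn).reshape(x.shape)
    return data, scale

def from_mx(data: torch.Tensor, scale: torch.Tensor, block_size: int = 32):
    """Toy dequant back to bf16; a real mx gemm kernel never materializes this."""
    return (data.to(torch.bfloat16).reshape(-1, block_size) * scale).reshape(data.shape)

M, K, N = 128, 256, 512
x = torch.randn(M, K, dtype=torch.bfloat16)     # input
w = torch.randn(N, K, dtype=torch.bfloat16)     # weight
go = torch.randn(M, N, dtype=torch.bfloat16)    # grad_output from upstream

# forward: out = x @ w.t(), with x and w cast to mxfp8
out = from_mx(*to_mx(x)) @ from_mx(*to_mx(w)).t()
# backward: grad_input and grad_weight, with grad_output also cast to mxfp8
grad_x = from_mx(*to_mx(go)) @ from_mx(*to_mx(w))
grad_w = from_mx(*to_mx(go.t())) @ from_mx(*to_mx(x))
# note: a real recipe scales each operand along the contraction dim of its gemm,
# so some tensors need a second cast with a different block layout
```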
We want:
- the cast kernels (`to_mx` in the pseudocode above) to be fast
- the gemm kernels to be fast

gemm kernel
Expected peak throughput on NVIDIA B200, without sparsity: 2.25 petaFLOPS for bf16, 4.5 petaFLOPS for fp8/fp6 (2x bf16), 9.0 petaFLOPS for fp4 (4x bf16) (source: https://resources.nvidia.com/en-us-blackwell-architecture, pages 19-20).
- `torch._scaled_mm`
- `torchao.ops.mx_fp8_bf16`
- `torchao.ops.mx_fp4_bf16`
- `torch._scaled_mm`
Once we have machines where benchmarking is possible, we should add easily reproducible gemm benchmarks and fill out the achieved TFLOPs for each kernel listed above.
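A reproducible harness could start as simple as the sketch below (the bf16 `torch.mm` is only a placeholder baseline and the shapes are arbitrary; the same timing and TFLOPs arithmetic applies to `torch._scaled_mm`, `torchao.ops.mx_fp8_bf16`, and `torchao.ops.mx_fp4_bf16` once their quantized inputs and scales are prepared).

```python
import torch
from torch.utils.benchmark import Timer

def gemm_tflops(M: int, K: int, N: int) -> float:
    """Time a bf16 gemm and return achieved TFLOPs (2*M*N*K flops per matmul)."""
    a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    b = torch.randn(K, N, dtype=torch.bfloat16, device="cuda")
    t = Timer(stmt="torch.mm(a, b)", globals={"torch": torch, "a": a, "b": b})
    seconds = t.blocked_autorange(min_run_time=1.0).median
    return 2 * M * N * K / seconds / 1e12

if __name__ == "__main__":
    for shape in [(8192, 8192, 8192), (16384, 8192, 8192)]:
        print(shape, f"{gemm_tflops(*shape):.0f} TFLOPs")
```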
scaling/casting kernels
Our current plan is to use `torch.compile`, same as we are doing with float8. The `float8_e8m0fnu` dtype was added to PyTorch in pytorch#147466 ("add the `torch.float8_e8m0fnu` dtype to PyTorch"); we need to update `torchao` to use this dtype for scales, and then ensure that PT2 works e2e. TODO issue.
There are three casts: `input`, `weight`, `grad_output`. The kernels for `input` and `grad_output` should be fused with preceding/subsequent ops as appropriate. TODO issue.
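The direction described above could look roughly like this (a minimal sketch, not torchao's `to_mx`: the helper name, block size of 32, and scale handling are simplified assumptions, and the scale is kept as an ordinary floating-point tensor, whereas the plan is to store it as `torch.float8_e8m0fnu`).

```python
import torch

@torch.compile  # rely on PT2 to generate one fused cast kernel, as with float8
def cast_to_mxfp8(x: torch.Tensor, block_size: int = 32):
    """Simplified blockwise cast: fp8 data + one power-of-two scale per block of 32."""
    blocks = x.reshape(-1, block_size)
    amax = blocks.abs().amax(dim=1, keepdim=True).clamp(min=2**-127)
    scale = torch.exp2(torch.floor(torch.log2(amax)))  # e8m0-representable
    data = (blocks / scale).to(torch.float8_e4m3fn).reshape(x.shape)
    # Once torchao adopts torch.float8_e8m0fnu for scales (pytorch#147466), the scale
    # would be stored in that dtype rather than as a regular float tensor.
    return data, scale

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(8192, 8192, dtype=torch.bfloat16, device=device)
data, scale = cast_to_mxfp8(x)
# Inspect the generated code with TORCH_LOGS=output_code, and profile with ncu to
# check for issues like the shared memory bank conflicts reported above.
```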
e2e training performance
From https://resources.nvidia.com/en-us-blackwell-architecture pages 19-20, on B200 the single GPU memory bandwidth we expect is 8 TB/s, the fp8/fp6 tensor core peak FLOPS is 4.5 petaFLOPS (without sparsity), and the fp4 tensor core peak FLOPS is 9.0 petaFLOPS (without sparsity).
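As a sanity check on what "fast enough" means for the cast kernels, a back-of-the-envelope roofline at 8 TB/s is sketched below (the tensor shape is a hypothetical activation size; block size 32 is the MX spec block size).

```python
# Rough lower bound for a memory-bound bf16 -> mxfp8 cast on B200.
M, K = 16384, 8192                                   # hypothetical activation shape
block_size = 32                                      # MX block size
bytes_read = M * K * 2                               # bf16 input
bytes_written = M * K * 1 + (M * K // block_size)    # fp8 data + one e8m0 scale byte per block
peak_bw_bytes_per_s = 8e12                           # 8 TB/s, from the Blackwell doc above
print(f"{(bytes_read + bytes_written) / peak_bw_bytes_per_s * 1e6:.1f} us")  # ~51 us
```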
e2e inference performance