
Commit 8e05144

Add mx_fp4_kernel
stack-info: PR: #1661, branch: drisspg/stack/34
1 parent 0646800 commit 8e05144

3 files changed, +101 −79 lines changed
Lines changed: 30 additions & 71 deletions
@@ -1,65 +1,42 @@
 import pytest
 import torch
-
 from torchao.float8.float8_utils import compute_error
-from torchao.ops import mx_fp8_bf16
+from torchao.ops import mx_fp8_bf16, mx_fp4_bf16
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
-    is_sm_at_least_100,
-)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100

 if not TORCH_VERSION_AT_LEAST_2_4:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

-
-def run_matrix_test(M: int, K: int, N: int) -> float:
-    """
-    Run matrix multiplication test with given dimensions.
-
-    Args:
-        M, K, N: Matrix dimensions
-
-    Returns:
-        float: SQNR (Signal-to-Quantization-Noise Ratio) value
-    """
+def run_matrix_test(M: int, K: int, N: int, format: str = "fp8") -> float:
     dtype = torch.bfloat16
     device = torch.device("cuda")
-
-    # Initialize matrices
+
     a = torch.rand((M, K), dtype=dtype, device=device)
     b = torch.rand((N, K), dtype=dtype, device=device)

-    # Convert to MX format
-    a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
-    b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)
-
-    a_fp8 = a_mx._data
-    b_fp8 = b_mx._data
-    assert b_fp8.is_contiguous()
-    b_fp8 = b_fp8.transpose(-1, -2)
-
-    # Get scales
-    a_scale_e8 = a_mx._scale_e8m0.view(M, K // 32)
-    b_scale_e8 = b_mx._scale_e8m0.view(N, K // 32)
-
-    a_scale_block = to_blocked(a_scale_e8)
-    b_scale_block = to_blocked(b_scale_e8)
+    fmt = torch.float8_e4m3fn if format == "fp8" else "fp4_e2m1"
+    mx_func = mx_fp8_bf16 if format == "fp8" else mx_fp4_bf16
+
+    a_mx = MXTensor.to_mx(a, fmt, 32)
+    b_mx = MXTensor.to_mx(b, fmt, 32)

-    # Get reference output
-    out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
-        -1, -2
-    )
+    a_data = a_mx._data
+    b_data = b_mx._data
+    assert b_data.is_contiguous()
+    b_data = b_data.transpose(-1, -2)

-    # Run implementation
-    out_e8_fp8 = mx_fp8_bf16(a_fp8, b_fp8, a_scale_block, b_scale_block)
+    a_scale = a_mx._scale_e8m0.view(M, K // 32)
+    b_scale = b_mx._scale_e8m0.view(N, K // 32)

-    # Calculate metrics
-    sqnr = compute_error(out_hp, out_e8_fp8)
+    a_scale_block = to_blocked(a_scale)
+    b_scale_block = to_blocked(b_scale)

-    return sqnr.item()
+    out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(-1, -2)
+    out = mx_func(a_data, b_data, a_scale_block, b_scale_block)

+    return compute_error(out_hp, out).item()

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
@@ -68,35 +45,17 @@ def run_matrix_test(M: int, K: int, N: int) -> float:
 @pytest.mark.parametrize(
     "size",
     [
-        # Small matrices
-        (128, 128, 128),
-        (256, 256, 256),
-        (384, 384, 384),
-        # Medium matrices
-        (512, 512, 512),
-        (640, 640, 640),
-        (768, 768, 768),
-        # Large matrices
-        (896, 896, 896),
-        (1024, 1024, 1024),
-        # Very large matrices
-        (8192, 8192, 8192),
-        # Non-square matrices
-        (128, 256, 384),
-        (256, 384, 512),
-        (384, 512, 640),
-        # Non-aligned matrices
-        (129, 256, 384),
-        (256, 384, 536),
-        (133, 512, 528),
+        (128, 128, 128), (256, 256, 256), (384, 384, 384),  # Small
+        (512, 512, 512), (768, 768, 768),  # Medium
+        (1024, 1024, 1024), (8192, 8192, 8192),  # Large
+        (128, 256, 384), (256, 384, 512),  # Non-square
+        (129, 256, 384), (133, 512, 528),  # Non-aligned
     ],
     ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
 )
-def test_matrix_multiplication(size):
-    """
-    Test matrix multiplication with various dimensions.
-    Verifies that the SQNR meets minimum quality threshold.
-    """
+@pytest.mark.parametrize("format", ["fp8", "fp4"])
+def test_matrix_multiplication(size, format):
     M, K, N = size
-    sqnr = run_matrix_test(M, K, N)
-    assert sqnr >= 80.0, f"SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
+    sqnr = run_matrix_test(M, K, N, format)
+    threshold = 80.0
+    assert sqnr >= threshold, f"{format} SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
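For reference, the fp4 path that the new `format` parametrization exercises can be run standalone, roughly as below. This is a sketch built only from the test code above; it assumes a torchao build with the CUTLASS MX kernels enabled (`BUILD_MX_KERNELS_CUTLASS`) and a GPU that passes the `is_sm_at_least_100` guard, and the 128x256x384 problem size is just one of the test cases.

```python
import torch

from torchao.float8.float8_utils import compute_error
from torchao.ops import mx_fp4_bf16
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked

M, K, N = 128, 256, 384
a = torch.rand((M, K), dtype=torch.bfloat16, device="cuda")
b = torch.rand((N, K), dtype=torch.bfloat16, device="cuda")

# Quantize to MX fp4 (e2m1) with a group size of 32 along K.
a_mx = MXTensor.to_mx(a, "fp4_e2m1", 32)
b_mx = MXTensor.to_mx(b, "fp4_e2m1", 32)

# E8M0 scales, one per 32-element group, rearranged into the blocked layout.
a_scale_block = to_blocked(a_mx._scale_e8m0.view(M, K // 32))
b_scale_block = to_blocked(b_mx._scale_e8m0.view(N, K // 32))

# Packed fp4 payloads; B is passed transposed, as in the test.
out = mx_fp4_bf16(a_mx._data, b_mx._data.transpose(-1, -2),
                  a_scale_block, b_scale_block)

ref = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(-1, -2)
print(compute_error(ref, out).item())  # SQNR; the test asserts >= 80
```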

torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu renamed to torchao/csrc/cuda/mx_kernels/mx_fp_bf16.cu

Lines changed: 42 additions & 8 deletions
@@ -34,7 +34,7 @@ using namespace cute;

 template<typename Element>
 constexpr int GetAlignment() {
-  if constexpr (std::is_same_v<Element, cutlass::nv_float4_t<cutlass::float_e2m1_t>>)
+  if constexpr (std::is_same_v<Element, cutlass::mx_float4_t<cutlass::float_e2m1_t>>)
     return 32;
   return 16;
 }
@@ -46,11 +46,7 @@ template <typename ElementA,
           typename ClusterShape,
           typename PerSmTileShape_MNK>
 void run_gemm(at::Tensor& a, at::Tensor& b, at::Tensor& a_scale,
-              at::Tensor& b_scale, at::Tensor& out) {
-  int M = a.size(0);
-  int K = a.size(1);
-  int N = b.size(1);
-
+              at::Tensor& b_scale, at::Tensor& out, int M, int K, int N) {
   // A matrix configuration
   using LayoutATag = cutlass::layout::RowMajor;        // Layout type for A matrix operand
   constexpr int AlignmentA = GetAlignment<ElementA>(); // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
@@ -225,9 +221,12 @@ at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
                        at::Tensor b_scale) {
 #if defined(BUILD_MX_KERNELS_CUTLASS)
   validate(a, b, a_scale, b_scale);
+  auto M = a.size(0);
+  auto K = a.size(1);
+  auto N = b.size(1);

   auto out =
-      at::empty({a.size(0), b.size(1)}, a.options().dtype(at::kBFloat16));
+      at::empty({M, N}, a.options().dtype(at::kBFloat16));
   using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
   using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
   using ElementD = cutlass::bfloat16_t;
@@ -236,16 +235,51 @@ at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
   using ClusterShape = Shape<_2,_1,_1>;
   using PerSmTileShape_MNK = Shape<_128,_128,_128>;

-  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out);
+  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out, M, K, N);
   return out;
 #else
   TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
   return at::Tensor{};
 #endif
 }

+at::Tensor mx_fp4_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
+                       at::Tensor b_scale) {
+#if defined(BUILD_MX_KERNELS_CUTLASS)
+  TORCH_CHECK(a.is_cuda(), "a must be CUDA tensor");
+  TORCH_CHECK(b.is_cuda(), "b must be CUDA tensor");
+  TORCH_CHECK(a_scale.is_cuda(), "a_scale must be CUDA tensor");
+  TORCH_CHECK(b_scale.is_cuda(), "b_scale must be CUDA tensor");
+
+  auto M = a.size(0);
+  auto K = a.size(1) * 2;
+  auto N = b.size(1);
+
+  auto out =
+      at::empty({M, N}, a.options().dtype(at::kBFloat16));
+  using ElementA = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+  using ElementB = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+  using ElementD = cutlass::bfloat16_t;
+
+  using MmaTileShape = Shape<_128,_128,_128>;
+  using ClusterShape = Shape<_2,_1,_1>;
+  using PerSmTileShape_MNK = Shape<_128,_128,_128>;
+
+  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out, M, K, N);
+  return out;
+#else
+  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
+  return at::Tensor{};
+#endif
+}
+
 TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
   m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
 }
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+  m.impl("torchao::mx_fp4_bf16", &mx_fp4_bf16);
+}
+

 } // namespace torchao
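The reason `run_gemm` now takes `M`, `K`, `N` explicitly is that the fp4 entry point cannot read the logical K off the payload: two e2m1 values are packed per byte, so it recovers it as `a.size(1) * 2`. A small Python sketch of that shape bookkeeping (the concrete sizes are illustrative, not from the commit):

```python
# Illustrative shape arithmetic for the fp4 GEMM entry point.
M, K, N = 128, 256, 384

a_packed = (M, K // 2)        # A payload: two e2m1 values per byte
b_packed = (K // 2, N)        # B payload after the caller-side transpose

K_logical = a_packed[1] * 2   # what mx_fp4_bf16 computes as a.size(1) * 2
assert K_logical == K

scales_per_row = K // 32      # one E8M0 scale per 32-element group along K
out_shape = (M, N)            # bf16 output, allocated as at::empty({M, N})
print(K_logical, scales_per_row, out_shape)
```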

torchao/ops.py

Lines changed: 29 additions & 0 deletions
@@ -23,6 +23,7 @@
     "s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
 )
 lib.define("mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")
+lib.define("mx_fp4_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")


 def register_custom_op(name):
@@ -644,3 +645,31 @@ def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
 def meta_mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
     """Meta impl for mx_fp8_bf16"""
     return torch.empty((A.size(0), B.size(1)), dtype=torch.bfloat16, device=A.device)
+
+def mx_fp4_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
+    """Defines a matmul between two fp4 tensors w/ MX scales in E8M0 and returns a bf16 tensor.
+
+    This op is a prototype and subject to change.
+
+    Note: The MX scales are E8M0 tensors stored in uint8 tensors (for now).
+    The layout of the scales is very particular, see:
+    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+    Args:
+        A: fp4 tensor (2 fp4 elements are packed into 1 byte -> elem0|elem1)
+        B: fp4 tensor (2 fp4 elements are packed into 1 byte -> elem0|elem1)
+        A_scale: E8M0 scale tensor for A with groupsize=32 in swizzled layout
+        B_scale: E8M0 scale tensor for B with groupsize=32 in swizzled layout
+
+    Returns:
+        MxN bf16 Tensor
+    """
+    return torch.ops.torchao.mx_fp4_bf16.default(A, B, A_scale, B_scale)
+
+
+@register_custom_op("torchao::mx_fp4_bf16")
+def meta_mx_fp4_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
+    """Meta impl for mx_fp4_bf16"""
+    # Assume the contraction happens in the K dim, so M and N are preserved after bit packing
+    return torch.empty((A.size(0), B.size(1)), dtype=torch.bfloat16, device=A.device)
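The meta registration only does shape propagation, so the op can be traced (e.g. under FakeTensor/torch.compile) without launching the CUDA kernel. A rough usage sketch with meta tensors; the shapes and the uint8 dtypes are illustrative, and the real kernel additionally expects the scales in the swizzled `to_blocked` layout, which the meta impl does not check:

```python
import torch
import torchao.ops  # defines the torchao::mx_fp4_bf16 schema and its meta impl

M, K, N = 128, 256, 384
A = torch.empty((M, K // 2), dtype=torch.uint8, device="meta")         # packed fp4 payload
B = torch.empty((K // 2, N), dtype=torch.uint8, device="meta")         # packed fp4 payload
A_scale = torch.empty((M, K // 32), dtype=torch.uint8, device="meta")  # E8M0 scales (pre-swizzle shape)
B_scale = torch.empty((N, K // 32), dtype=torch.uint8, device="meta")

out = torch.ops.torchao.mx_fp4_bf16(A, B, A_scale, B_scale)
print(out.shape, out.dtype)  # torch.Size([128, 384]) torch.bfloat16
```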
