@@ -1,65 +1,42 @@
 import pytest
 import torch
-
 from torchao.float8.float8_utils import compute_error
-from torchao.ops import mx_fp8_bf16
+from torchao.ops import mx_fp8_bf16, mx_fp4_bf16
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
-    is_sm_at_least_100,
-)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100
 
 if not TORCH_VERSION_AT_LEAST_2_4:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
-
-def run_matrix_test(M: int, K: int, N: int) -> float:
-    """
-    Run matrix multiplication test with given dimensions.
-
-    Args:
-        M, K, N: Matrix dimensions
-
-    Returns:
-        float: SQNR (Signal-to-Quantization-Noise Ratio) value
-    """
+def run_matrix_test(M: int, K: int, N: int, format: str = "fp8") -> float:
     dtype = torch.bfloat16
     device = torch.device("cuda")
-
-    # Initialize matrices
+
     a = torch.rand((M, K), dtype=dtype, device=device)
     b = torch.rand((N, K), dtype=dtype, device=device)
 
-    # Convert to MX format
-    a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
-    b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)
-
-    a_fp8 = a_mx._data
-    b_fp8 = b_mx._data
-    assert b_fp8.is_contiguous()
-    b_fp8 = b_fp8.transpose(-1, -2)
-
-    # Get scales
-    a_scale_e8 = a_mx._scale_e8m0.view(M, K // 32)
-    b_scale_e8 = b_mx._scale_e8m0.view(N, K // 32)
-
-    a_scale_block = to_blocked(a_scale_e8)
-    b_scale_block = to_blocked(b_scale_e8)
+    fmt = torch.float8_e4m3fn if format == "fp8" else "fp4_e2m1"
+    mx_func = mx_fp8_bf16 if format == "fp8" else mx_fp4_bf16
+
+    a_mx = MXTensor.to_mx(a, fmt, 32)
+    b_mx = MXTensor.to_mx(b, fmt, 32)
 
-    # Get reference output
-    out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
-        -1, -2
-    )
+    a_data = a_mx._data
+    b_data = b_mx._data
+    assert b_data.is_contiguous()
+    b_data = b_data.transpose(-1, -2)
 
-    # Run implementation
-    out_e8_fp8 = mx_fp8_bf16(a_fp8, b_fp8, a_scale_block, b_scale_block)
+    a_scale = a_mx._scale_e8m0.view(M, K // 32)
+    b_scale = b_mx._scale_e8m0.view(N, K // 32)
 
-    # Calculate metrics
-    sqnr = compute_error(out_hp, out_e8_fp8)
+    a_scale_block = to_blocked(a_scale)
+    b_scale_block = to_blocked(b_scale)
 
-    return sqnr.item()
+    out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(-1, -2)
+    out = mx_func(a_data, b_data, a_scale_block, b_scale_block)
 
+    return compute_error(out_hp, out).item()
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
@@ -68,35 +45,17 @@ def run_matrix_test(M: int, K: int, N: int) -> float:
 @pytest.mark.parametrize(
     "size",
     [
-        # Small matrices
-        (128, 128, 128),
-        (256, 256, 256),
-        (384, 384, 384),
-        # Medium matrices
-        (512, 512, 512),
-        (640, 640, 640),
-        (768, 768, 768),
-        # Large matrices
-        (896, 896, 896),
-        (1024, 1024, 1024),
-        # Very large matrices
-        (8192, 8192, 8192),
-        # Non-square matrices
-        (128, 256, 384),
-        (256, 384, 512),
-        (384, 512, 640),
-        # Non-aligned matrices
-        (129, 256, 384),
-        (256, 384, 536),
-        (133, 512, 528),
+        (128, 128, 128), (256, 256, 256), (384, 384, 384),  # Small
+        (512, 512, 512), (768, 768, 768),  # Medium
+        (1024, 1024, 1024), (8192, 8192, 8192),  # Large
+        (128, 256, 384), (256, 384, 512),  # Non-square
+        (129, 256, 384), (133, 512, 528),  # Non-aligned
     ],
     ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
 )
-def test_matrix_multiplication(size):
-    """
-    Test matrix multiplication with various dimensions.
-    Verifies that the SQNR meets minimum quality threshold.
-    """
+@pytest.mark.parametrize("format", ["fp8", "fp4"])
+def test_matrix_multiplication(size, format):
     M, K, N = size
-    sqnr = run_matrix_test(M, K, N)
-    assert sqnr >= 80.0, f"SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
+    sqnr = run_matrix_test(M, K, N, format)
+    threshold = 80.0
+    assert sqnr >= threshold, f"{format} SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
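For reference, the quality gate in the test above is an SQNR value in decibels, computed by torchao's compute_error between the emulated bf16 matmul and the kernel output. A minimal self-contained sketch of that metric, assuming the standard 20*log10(||signal|| / ||noise||) definition (the sqnr_db helper below is illustrative only, not part of this PR):

import torch

def sqnr_db(ref: torch.Tensor, out: torch.Tensor) -> float:
    # Signal-to-quantization-noise ratio in dB; higher means the
    # output is closer to the reference.
    signal = torch.norm(ref)
    noise = torch.norm(ref - out)
    return (20 * torch.log10(signal / noise)).item()

Note that the reference out_hp is built from the already-quantized MXTensors (a_mx.to_dtype(...) @ b_mx.to_dtype(...)), so the assertion measures only the kernel's matmul/accumulation error rather than quantization error, which is why a single 80 dB threshold can plausibly cover both the fp8 and fp4 paths.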