Commit c8d768e (1 parent: be7806c)

Update

[ghstack-poisoned]

2 files changed: +40, -11 lines

test/prototype/mx_formats/test_mx_linear.py (4 additions, 4 deletions)

@@ -56,19 +56,19 @@ def run_around_tests():
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("bias", [True, False])
-@pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)])
+@pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)])
 def test_linear_eager(elem_dtype, bias, input_shape):
     """
     Smoke test for training linear module with mx weight
     """
     grad_shape = list(input_shape)
-    grad_shape[-1] = 6
+    grad_shape[-1] = 128

     m = nn.Sequential(
-        nn.Linear(8, 6, bias=bias, device="cuda"),
+        nn.Linear(256, 128, bias=bias, device="cuda"),
     )
     m_mx = copy.deepcopy(m)
-    block_size = 2
+    block_size = 32
     swap_linear_with_mx_linear(m_mx, elem_dtype, block_size)

     x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
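
For context, a minimal sketch of what the resized test now exercises in eager mode: swapping a plain nn.Linear for its MX counterpart at block_size=32, then running a forward and backward pass. The import path for swap_linear_with_mx_linear and the choice of elem_dtype are assumptions based on the test file's surrounding context, not part of this diff.

# Hypothetical usage sketch of the updated test setup (not part of the diff).
# The module path below is an assumption based on the torchao MX prototype layout.
import copy

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear  # assumed path

elem_dtype = torch.float8_e4m3fn  # one entry of SUPPORTED_ELEM_DTYPES
block_size = 32                   # matches the new hardware-friendly block size

m = nn.Sequential(nn.Linear(256, 128, bias=True, device="cuda"))
m_mx = copy.deepcopy(m)
swap_linear_with_mx_linear(m_mx, elem_dtype, block_size)

x = torch.randn(128, 256, device="cuda", requires_grad=True)
y = m_mx(x)          # forward through the MX linear
y.sum().backward()   # smoke-test the backward as well, as the test does via grad_shape

The larger (128, 256) shapes and block_size=32 ensure every row splits evenly into MX blocks, which the new fp8 matmul path below requires.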

torchao/prototype/mx_formats/mx_ops.py (36 additions, 7 deletions)

@@ -27,6 +27,9 @@
     MXTensor,
     tensor_size_hp_to_fp4x2,
 )
+from transformer_nuggets.mx.to_blocked import (
+    to_blocked,
+)

 aten = torch.ops.aten

@@ -63,13 +66,39 @@ def mx_mm(aten_op, args, kwargs=None):
     a = args[0]
     b = args[1]
     assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
-    a_hp = a.to_dtype(a._orig_dtype)
-    b_hp = b.to_dtype(b._orig_dtype)
-    # assert memory layout we expect to be required in hardware
-    assert a_hp.is_contiguous()
-    assert b_hp.t().is_contiguous()
-    res = aten_op(a_hp, b_hp)
-    return res
+
+    if a._data.dtype is torch.float8_e4m3fn and b._data.dtype is torch.float8_e4m3fn:
+
+        assert a._block_size == 32 and b._block_size == 32
+
+        a_s0 = a._scale_e8m0.reshape(a._data.shape[0], -1)
+        a_s1 = to_blocked(a_s0)
+        b_s0 = b._scale_e8m0.reshape(b._data.shape[1], -1)
+        b_s1 = to_blocked(b_s0)
+        out_mx_real = torch._scaled_mm(
+            a._data,
+            b._data,
+            # a_scales is really b_scales, and vice versa. Probably switched in cuBLAS kernel?
+            b_s1,
+            a_s1,
+            None,
+            None,
+            a._orig_dtype,
+            False,
+            None,
+            None,
+            1,  # DataType.E8M0
+        )
+        return out_mx_real
+
+    else:
+        a_hp = a.to_dtype(a._orig_dtype)
+        b_hp = b.to_dtype(b._orig_dtype)
+        # assert memory layout we expect to be required in hardware
+        assert a_hp.is_contiguous()
+        assert b_hp.t().is_contiguous()
+        res = aten_op(a_hp, b_hp)
+        return res


 @implements([aten.t.default])
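
The new fp8 branch hands raw float8_e4m3fn data plus blocked e8m0 scales straight to torch._scaled_mm, so the main bookkeeping is reshaping the flat per-block scales into (rows, blocks_per_row) form before to_blocked rearranges them. A small sketch of that shape arithmetic, using the sizes from the updated test; the uint8 stand-in for the e8m0 scale tensor is an assumption for illustration.

# Shape bookkeeping for the new fp8 branch of mx_mm (illustration only).
# For a (128, 256) fp8 tensor with block_size=32 there is one e8m0 scale per
# 32-element block along the reduction dim, stored flat on the MXTensor.
import torch

M, K, block_size = 128, 256, 32
blocks_per_row = K // block_size                   # 256 // 32 == 8
# Stand-in for a._scale_e8m0 (assumed to be a flat uint8 tensor of biased exponents).
scale_e8m0 = torch.randint(0, 255, (M * blocks_per_row,), dtype=torch.uint8)

a_s0 = scale_e8m0.reshape(M, -1)                   # same reshape as in mx_mm
assert a_s0.shape == (128, 8)                      # one scale per 32-wide block
# to_blocked(a_s0) then rearranges these scales into the tiled layout expected
# by the cuBLAS block-scaled matmul behind torch._scaled_mm.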
