 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.amp import custom_bwd, custom_fwd


 def make_dequant_configs(block_sizes, num_warps):
@@ -92,7 +92,7 @@ def dequant248(qweight, scales, qzeros, g_idx, bits, maxq=None):

     out = torch.empty((infeatures, outfeatures), device="cuda", dtype=torch.float16)
     numels = out.numel()
-    maxq = 2**bits - 1 if maxq is None else maxq
+    maxq = 2**bits - 1 if maxq is None else maxq
     grid = lambda meta: (triton.cdiv(numels, meta["X_BLOCK"]),)  # noqa: E731

     dequant_kernel_248[grid](
@@ -119,15 +119,15 @@ def quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq=None, tra

 class QuantLinearFunction(torch.autograd.Function):
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type="cuda")
     def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
         output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
         ctx.save_for_backward(qweight, scales, qzeros, g_idx)
         ctx.bits, ctx.maxq = bits, maxq
         return output

     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
         qweight, scales, qzeros, g_idx = ctx.saved_tensors
         bits, maxq = ctx.bits, ctx.maxq
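For context on the decorator change: PyTorch 2.4 moved custom_fwd/custom_bwd from torch.cuda.amp to torch.amp and made the target device explicit via device_type. A minimal sketch of the new calling convention on a toy autograd Function, with a hypothetical ScaleFunction standing in for QuantLinearFunction (the scale logic is illustrative only, not part of this file):

import torch
from torch.amp import custom_bwd, custom_fwd  # new location (PyTorch >= 2.4)


class ScaleFunction(torch.autograd.Function):
    # Hypothetical toy Function standing in for QuantLinearFunction.
    @staticmethod
    @custom_fwd(device_type="cuda")  # previously: @torch.cuda.amp.custom_fwd
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x * scale

    @staticmethod
    @custom_bwd(device_type="cuda")  # previously: @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        # Gradient w.r.t. x only; the Python float `scale` gets no gradient.
        return grad_output * ctx.scale, None

Under torch.autocast, custom_bwd ensures the backward pass runs with the same autocast state that was active during the corresponding forward, which is why both decorators are updated together here.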