
Commit cdec2a7

fix deprecated (#447)
1 parent 6fea9e5 commit cdec2a7

1 file changed (+4, -4 lines)


gptqmodel/nn_modules/triton_utils/dequant.py

Lines changed: 4 additions & 4 deletions
@@ -3,7 +3,7 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.amp import custom_bwd, custom_fwd
 
 
 def make_dequant_configs(block_sizes, num_warps):
@@ -92,7 +92,7 @@ def dequant248(qweight, scales, qzeros, g_idx, bits, maxq=None):
 
     out = torch.empty((infeatures, outfeatures), device="cuda", dtype=torch.float16)
     numels = out.numel()
-    maxq = 2**bits - 1 if maxq is None else maxq
+    maxq = 2 ** bits - 1 if maxq is None else maxq
     grid = lambda meta: (triton.cdiv(numels, meta["X_BLOCK"]),)  # noqa: E731
 
     dequant_kernel_248[grid](
@@ -119,15 +119,15 @@ def quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq=None, tra
 
 class QuantLinearFunction(torch.autograd.Function):
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type="cuda")
     def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
         output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
         ctx.save_for_backward(qweight, scales, qzeros, g_idx)
         ctx.bits, ctx.maxq = bits, maxq
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
         qweight, scales, qzeros, g_idx = ctx.saved_tensors
         bits, maxq = ctx.bits, ctx.maxq
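
Background on the change: recent PyTorch releases deprecate torch.cuda.amp.custom_fwd / custom_bwd in favor of the device-generic torch.amp.custom_fwd / custom_bwd, which require an explicit device_type argument, so the decorators above are switched to the new form. A minimal sketch of the replacement API is shown below; the DoubleIt op is a made-up illustration, not code from this repository.

import torch
from torch.amp import custom_bwd, custom_fwd


class DoubleIt(torch.autograd.Function):
    # Hypothetical example op: f(x) = 2 * x

    @staticmethod
    @custom_fwd(device_type="cuda")  # device_type is required by the torch.amp variant
    def forward(ctx, x):
        return x * 2

    @staticmethod
    @custom_bwd(device_type="cuda")  # backward runs under the same autocast state as forward
    def backward(ctx, grad_output):
        # d/dx (2 * x) = 2
        return grad_output * 2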

0 commit comments
