Skip to content

Commit 2f2063b

Browse files
committed
Added k<256 quantile estimate.
1 parent 98cbc4b commit 2f2063b

File tree

2 files changed

+74
-30
lines changed

2 files changed

+74
-30
lines changed

bitsandbytes/functional.py

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
182182

183183

184184

185-
def create_dynamic_map(signed=True, n=7):
185+
def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
186186
"""
187187
Creates the dynamic quantiztion map.
188188
@@ -203,28 +203,32 @@ def create_dynamic_map(signed=True, n=7):
203203
# these are additional items that come from the case
204204
# where all the exponent bits are zero and no
205205
# indicator bit is present
206-
additional_items = 2 ** (7 - n) - 1
206+
non_sign_bits = total_bits - (1 if signed else 0)
207+
additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
207208
if not signed:
208209
additional_items = 2 * additional_items
209-
for i in range(n):
210-
fraction_items = (
211-
2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
212-
)
210+
for i in range(max_exponent_bits):
211+
fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1))
213212
boundaries = torch.linspace(0.1, 1, fraction_items)
214213
means = (boundaries[:-1] + boundaries[1:]) / 2.0
215-
data += ((10 ** (-(n - 1) + i)) * means).tolist()
214+
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
216215
if signed:
217-
data += (-(10 ** (-(n - 1) + i)) * means).tolist()
216+
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
218217

219-
if additional_items > 0:
220-
boundaries = torch.linspace(0.1, 1, additional_items + 1)
221-
means = (boundaries[:-1] + boundaries[1:]) / 2.0
222-
data += ((10 ** (-(n - 1) + i)) * means).tolist()
223-
if signed:
224-
data += (-(10 ** (-(n - 1) + i)) * means).tolist()
218+
if additional_items > 0:
219+
boundaries = torch.linspace(0.1, 1, additional_items + 1)
220+
means = (boundaries[:-1] + boundaries[1:]) / 2.0
221+
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
222+
if signed:
223+
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
225224

226225
data.append(0)
227226
data.append(1.0)
227+
228+
gap = 256 - len(data)
229+
for i in range(gap):
230+
data.append(0)
231+
228232
data.sort()
229233
return Tensor(data)
230234

@@ -371,9 +375,7 @@ def nvidia_transform(
371375
return out, new_state
372376

373377

374-
def estimate_quantiles(
375-
A: Tensor, out: Tensor = None, offset: float = 1 / 512
376-
) -> Tensor:
378+
def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
377379
'''
378380
Estimates 256 equidistant quantiles on the input tensor eCDF.
379381
@@ -393,25 +395,36 @@ def estimate_quantiles(
393395
out : torch.Tensor
394396
Tensor with the 256 estimated quantiles.
395397
offset : float
396-
The offset for the first and last quantile from 0 and 1. Default: 1/512
398+
The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
399+
num_quantiles : int
400+
The number of equally spaced quantiles.
397401
398402
Returns
399403
-------
400404
torch.Tensor:
401405
The 256 quantiles in float32 datatype.
402406
'''
407+
if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.')
408+
if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}")
409+
if num_quantiles < 256 and offset == 1/(512):
410+
# override default arguments
411+
offset = 1/(2*num_quantiles)
412+
403413
if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
404414
is_on_gpu([A, out])
415+
device = pre_call(A.device)
405416
if A.dtype == torch.float32:
406-
lib.cestimate_quantiles_fp32(
407-
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
408-
)
417+
lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
409418
elif A.dtype == torch.float16:
410-
lib.cestimate_quantiles_fp16(
411-
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
412-
)
419+
lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
413420
else:
414421
raise NotImplementedError(f"Not supported data type {A.dtype}")
422+
post_call(device)
423+
424+
if num_quantiles < 256:
425+
idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
426+
out = out[idx]
427+
415428
return out
416429

417430

tests/test_functional.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
import einops
77
import pytest
88
import torch
9+
import numpy as np
910

1011
import bitsandbytes as bnb
1112
from bitsandbytes import functional as F
13+
from scipy.stats import norm
1214

1315
torch.set_printoptions(
1416
precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
@@ -2094,19 +2096,34 @@ def test_fp8_quant():
20942096

20952097
def test_few_bit_quant():
20962098

2099+
print('')
20972100
for bits in range(2, 9):
2098-
for method in ['linear', 'fp8']:
2101+
print('='*30, bits, '='*30)
2102+
for method in ['linear', 'fp8', 'dynamic', 'quantile']:
2103+
abserrs = []
2104+
relerrs = []
20992105
code = None
21002106
if method == 'linear':
21012107
code = F.create_linear_map(True, bits=bits).cuda()
21022108
elif method == 'fp8':
21032109
ebits = math.ceil(bits/2)
21042110
pbits = bits-ebits-1
21052111
code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
2106-
print(ebits, pbits, bits)
2107-
print(code)
2112+
elif method == 'dynamic':
2113+
code = F.create_dynamic_map(True, bits-0, bits).cuda()
2114+
elif method == 'quantile':
2115+
values = torch.randn(2048, 2048, device='cuda')
2116+
q = F.estimate_quantiles(values, offset= 1/(2*(2**bits)), num_quantiles=2**bits)
2117+
gap = 256-q.numel()
2118+
q = q.tolist()
2119+
for i in range(gap):
2120+
q.append(0)
2121+
q = torch.Tensor(q).cuda()
2122+
2123+
q /= q.abs().max()
2124+
code, idx = torch.sort(q)
2125+
print(method, (code==0).sum())
21082126
assert code.numel() == 256
2109-
print(bits)
21102127
for i in range(10):
21112128

21122129
values = torch.randn(1, 32, device='cuda')
@@ -2127,11 +2144,25 @@ def test_few_bit_quant():
21272144
v2 = F.dequantize(q2, S2)
21282145

21292146
idx = torch.isclose(q1.int(), q2.int())
2147+
err2 = torch.abs(v2-values)
2148+
abserrs.append(err2.mean().item())
2149+
relerrs.append((err2/(1e-10+values).abs()).mean().item())
21302150
if idx.sum():
21312151
# some weird cases
21322152
err1 = torch.abs(v1-values).mean()
2133-
err2 = torch.abs(v2-values).mean()
2134-
assert err2 <= err1
2153+
assert err2.mean() <= err1
21352154

21362155
else:
21372156
torch.testing.assert_allclose(q1, q2)
2157+
print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
2158+
2159+
2160+
def test_kbit_quantile_estimation():
2161+
for i in range(100):
2162+
data = torch.randn(1024, 1024, device='cuda')
2163+
for bits in range(2, 9):
2164+
p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits)
2165+
val1 = torch.Tensor(norm.ppf(p)).cuda()
2166+
val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
2167+
err = torch.abs(val1-val2).mean()
2168+
assert err < 0.035

0 commit comments

Comments
 (0)