
Commit ab07e54

int3 and Jokeren authored
[AUTOTUNER] Make autotuner take do_bench as a parameter (triton-lang#4496)
This makes the autotuner device-agnostic. Instead of having to know about the existence of e.g. do_bench_cudagraph, it can let the callers decide which backend-specific benchmarking function to use. See discussion in triton-lang#4417.

Co-authored-by: Keren Zhou <[email protected]>
1 parent 2cc227d commit ab07e54
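From the caller's side the change looks like the sketch below — a minimal, hypothetical example assuming a CUDA-capable device (the copy kernel, shapes, and names are made up for illustration; if do_bench is omitted, the autotuner falls back to the active driver's get_benchmarker()):

import torch
import triton
import triton.language as tl

# The autotuner now accepts any callable of the form (kernel_call, quantiles) -> timings,
# so the caller chooses the benchmarking strategy instead of the autotuner hard-coding one.
do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(kernel_call, quantiles=quantiles)


@triton.autotune(
    configs=[triton.Config(kwargs={'BLOCK_SIZE': 64}),
             triton.Config(kwargs={'BLOCK_SIZE': 256})],
    key=['n_elements'],
    do_bench=do_bench,  # the parameter introduced by this commit
)
@triton.jit
def copy_kernel(dst_ptr, src_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets, mask=mask), mask=mask)


src = torch.randn(4096, device='cuda')
dst = torch.empty_like(src)
grid = lambda meta: (triton.cdiv(src.numel(), meta['BLOCK_SIZE']),)
copy_kernel[grid](dst, src, src.numel())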

8 files changed: +77 -26 lines changed

python/test/unit/hopper/test_flashattention.py

Lines changed: 2 additions & 4 deletions
@@ -435,8 +435,6 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
 @triton.testing.perf_report(configs)
 def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
     assert mode in ['fwd', 'bwd']
-    warmup = 25
-    rep = 100
     if provider == "triton":
         q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
         k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
@@ -447,7 +445,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
             o = fn()
             do = torch.randn_like(o)
             fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        ms = triton.testing.do_bench(fn)
         return ms
     if provider == "flash":
         lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device)
@@ -459,7 +457,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
             o = fn()
             do = torch.randn_like(o)
             fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        ms = triton.testing.do_bench(fn)
         return ms

python/test/unit/language/test_decorator.py

Lines changed: 3 additions & 1 deletion
@@ -33,7 +33,9 @@ def test_triton_heuristic(device):
     src = torch.empty(N, device=device)
     dst = torch.zeros(N, device=device)

-    @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], warmup=1, rep=1)
+    do_bench = lambda kernel, quantiles: triton.testing.do_bench(kernel, quantiles=quantiles, warmup=1, rep=1)
+
+    @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], do_bench=do_bench)
     @triton.heuristics({'EVEN_N': lambda nargs: nargs['N'] % 2 == 0})  # test kwargs
     @triton.heuristics({'EVEN_src': lambda nargs: nargs['src'].data_ptr() % 2 == 0})  # test args
     @triton.jit

python/test/unit/runtime/test_autotuner.py

Lines changed: 8 additions & 4 deletions
@@ -5,6 +5,10 @@
 import pytest


+def do_bench(kernel_call, quantiles):
+    return triton.testing.do_bench(kernel_call, quantiles=quantiles, warmup=1, rep=1)
+
+
 @pytest.mark.parametrize('use_cuda_graph', [False, True])
 def test_kwargs(use_cuda_graph: bool, device: str):
     M, N = 1024, 16
@@ -13,7 +17,7 @@ def test_kwargs(use_cuda_graph: bool, device: str):

     configs = [triton.Config(kwargs={'BLOCK_SIZE_M': 32}), triton.Config(kwargs={'BLOCK_SIZE_M': 128})]

-    @triton.autotune(configs=configs, key=['M'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph)
+    @triton.autotune(configs=configs, key=['M'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph, do_bench=do_bench)
     @triton.jit
     def _kernel(dst, src, stride_m: tl.constexpr, M, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_M: tl.constexpr):
         offsets_m = tl.program_id(0) * stride_m + tl.arange(0, BLOCK_SIZE_M)
@@ -34,7 +38,7 @@ def test_restore(device):

     configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]

-    @triton.autotune(configs=configs, key=['N'], restore_value=['src'], warmup=1, rep=1)
+    @triton.autotune(configs=configs, key=['N'], restore_value=['src'], do_bench=do_bench)
     @triton.jit
     def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
         offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -64,7 +68,7 @@ def _post_hook(*args, exception):
         values["has_exception"] = True
     assert values["counter"] == 0

-    @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, pre_hook=_pre_hook, post_hook=_post_hook)
+    @triton.autotune(configs=configs, key=['N'], do_bench=do_bench, pre_hook=_pre_hook, post_hook=_post_hook)
     @triton.heuristics({"N_STAGES": lambda nargs: 100 if nargs['N'] == 4096 else 4})
     @triton.jit
     def _kernel(src, N, N_STAGES: tl.constexpr, BLOCK_SIZE: tl.constexpr):
@@ -115,7 +119,7 @@ def perf_model(*args, **kwargs):
     else:
         prune_configs_by = {'early_config_prune': early_config_prune}

-    @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, warmup=1, rep=1)
+    @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, do_bench=do_bench)
     @triton.jit
     def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
         offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

python/triton/backends/driver.py

Lines changed: 14 additions & 0 deletions
@@ -1,4 +1,11 @@
 from abc import ABCMeta, abstractmethod, abstractclassmethod
+from typing import Callable, List, Protocol, Sequence
+
+
+class Benchmarker(Protocol):
+
+    def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]:
+        pass


 class DriverBase(metaclass=ABCMeta):
@@ -11,6 +18,13 @@ def is_active(self):
     def get_current_target(self):
         pass

+    @abstractmethod
+    def get_benchmarker(self) -> Benchmarker:
+        """
+        Return the benchmarking function that this backend should use by default.
+        """
+        raise NotImplementedError
+
     def __init__(self) -> None:
         pass
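Any callable of this shape satisfies the Benchmarker protocol. As a sketch (not part of this commit), a caller who previously relied on use_cuda_graph=True could keep CUDA-graph timing by passing a wrapper around do_bench_cudagraph instead:

from triton.testing import do_bench_cudagraph


# Satisfies Benchmarker: takes the kernel closure plus keyword-only quantiles, returns timings.
def cudagraph_benchmarker(kernel_call, *, quantiles, **kwargs):
    return do_bench_cudagraph(kernel_call, rep=100, quantiles=quantiles)


# Usage (configs/key as in the tests above):
#   @triton.autotune(configs=configs, key=['N'], do_bench=cudagraph_benchmarker)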

python/triton/runtime/autotuner.py

Lines changed: 40 additions & 13 deletions
@@ -6,9 +6,9 @@
 import inspect
 from typing import Dict

-from ..testing import do_bench, do_bench_cudagraph
 from .jit import KernelInterface
 from .errors import OutOfResources
+from .driver import driver


 class Autotuner(KernelInterface):
@@ -24,9 +24,10 @@ def __init__(
         pre_hook=None,
         post_hook=None,
         prune_configs_by: Dict = None,
-        warmup=25,
-        rep=100,
+        warmup=None,
+        rep=None,
         use_cuda_graph=False,
+        do_bench=None,
     ):
         """
         :param prune_configs_by: a dict of functions that are used to prune configs, fields:
@@ -88,10 +89,36 @@ def _post_hook(args, exception):
         self.base_fn = fn
         while not inspect.isfunction(self.base_fn):
             self.base_fn = self.base_fn.fn
-        self.num_warmups = warmup
-        self.num_reps = rep
-        import torch
-        self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
+
+        # If we got explicitly called via the old interface, raise a warning
+        # and proceed with the old behavior.
+        if warmup is not None or rep is not None or use_cuda_graph:
+            import warnings
+            warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "
+                           "https://github.com/triton-lang/triton/pull/4496 for details."), DeprecationWarning,
+                          stacklevel=1)
+            if use_cuda_graph:
+                from ..testing import do_bench_cudagraph
+                self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
+                    kernel_call,
+                    rep=rep if rep is not None else 100,
+                    quantiles=quantiles,
+                )
+                return
+
+            import triton.testing
+            self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
+                kernel_call,
+                warmup=warmup if warmup is not None else 25,
+                rep=rep if rep is not None else 100,
+                quantiles=quantiles,
+            )
+            return
+
+        if do_bench is None:
+            self.do_bench = driver.active.get_benchmarker()
+        else:
+            self.do_bench = do_bench

     def _bench(self, *args, config, **meta):
         from ..compiler.errors import CompileTimeAssertionFailure
@@ -125,9 +152,7 @@ def kernel_call():
             self.post_hook(args, exception=None)

         try:
-            if self.use_cuda_graph:
-                return do_bench_cudagraph(kernel_call, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
-            return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
+            return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
         except (OutOfResources, CompileTimeAssertionFailure):
             return [float("inf"), float("inf"), float("inf")]
@@ -257,7 +282,7 @@ def __str__(self):


 def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None,
-             warmup=25, rep=100, use_cuda_graph=False):
+             warmup=None, rep=None, use_cuda_graph=False, do_bench=None):
     """
     Decorator for auto-tuning a :code:`triton.jit`'d function.
@@ -305,10 +330,12 @@ def kernel(x_ptr, x_size, **META):
         'args': a list of arguments passed to the kernel.
         'exception': the exception raised by the kernel in case of a compilation or runtime error.
     :type post_hook: lambda args, exception
-    :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25.
+    :param warmup: warmup time (in ms) to pass to benchmarking (deprecated).
     :type warmup: int
-    :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
+    :param rep: repetition time (in ms) to pass to benchmarking (deprecated).
    :type rep: int
+    :param do_bench: a benchmark function to measure the time of each run.
+    :type do_bench: lambda fn, quantiles
     """

     def decorator(fn):
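Migration note: the deprecated path above simply reinstates the old defaults (warmup=25, rep=100), so the explicit equivalent is to fold those values into the callable passed as do_bench. A sketch, reusing the decorator arguments from the docstring:

import triton.testing

# Before (still works, but now emits a DeprecationWarning):
#   @triton.autotune(configs=configs, key=['N'], warmup=25, rep=100)

# After: the warmup/rep knobs move into the caller-supplied benchmarker.
bench = lambda kernel_call, quantiles: triton.testing.do_bench(
    kernel_call, warmup=25, rep=100, quantiles=quantiles)

# @triton.autotune(configs=configs, key=['N'], do_bench=bench)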

python/tutorials/06-fused-attention.py

Lines changed: 2 additions & 4 deletions
@@ -601,8 +601,6 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
 @triton.testing.perf_report(configs)
 def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"):
     assert mode in ["fwd", "bwd"]
-    warmup = 25
-    rep = 100
     dtype = torch.float16
     if "triton" in provider:
         q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
@@ -620,15 +618,15 @@ def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, dev
             o = fn()
             do = torch.randn_like(o)
             fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        ms = triton.testing.do_bench(fn)
     if provider == "flash":
         qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
         fn = lambda: flash_attn_func(qkv, causal=causal)
         if mode == "bwd":
             o = fn()
             do = torch.randn_like(o)
             fn = lambda: o.backward(do, retain_graph=True)
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        ms = triton.testing.do_bench(fn)
     flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM
     total_flops = 2 * flops_per_matmul
     if causal:

third_party/amd/backend/driver.py

Lines changed: 4 additions & 0 deletions
@@ -499,3 +499,7 @@ def get_current_target(self):
         arch = device_properties['arch']
         warp_size = device_properties['warpSize']
         return GPUTarget("hip", arch.split(':')[0], warp_size)
+
+    def get_benchmarker(self):
+        from triton.testing import do_bench
+        return do_bench

third_party/nvidia/backend/driver.py

Lines changed: 4 additions & 0 deletions
@@ -448,3 +448,7 @@ def get_device_interface(self):
     def is_active():
         import torch
         return torch.cuda.is_available() and (torch.version.hip is None)
+
+    def get_benchmarker(self):
+        from triton.testing import do_bench
+        return do_bench
