
Commit 9642333

zheng-xq authored and Mikhail Zolotukhin committed
Add elementwise benchmarks and comparisons. (pytorch#155)
1 parent 4998cc6 commit 9642333

File tree: 5 files changed, +140 −17 lines


benchmarks/tensorexpr/elementwise.py

Lines changed: 109 additions & 13 deletions
@@ -1,7 +1,18 @@
 import framework
+import itertools
+import numpy as np
+import torch
 
-
-class ElementMulBench(framework.Benchmark):
+# A template class for elementwise operations.
+# A derived class overrides these class variables to customize its behavior.
+class ElementBench(framework.Benchmark):
+    # List of customization class variables.
+    op_str = None
+    binary_op_pt_func = None
+    binary_op_np_func = None
+    unary_op_pt_func = None
+    unary_op_np_func = None
+    split_input = True
     def __init__(self, mode, device, N):
         super().__init__(mode, device)
         self.N = N
@@ -11,27 +22,60 @@ def __init__(self, mode, device, N):
         self.d4 = self.rand([N], device=device, requires_grad=self.requires_grad)
         self.inputs = [self.d1, self.d2, self.d3, self.d4]
 
+    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
+        if not binary_op:
+            binary_op = lambda x, y: x + y
+        if not unary_op:
+            unary_op = lambda x: x
+        if self.split_input:
+            d1 = unary_op(d1)
+            d2 = unary_op(d2)
+            d3 = unary_op(d3)
+            d4 = unary_op(d4)
+        else:
+            d2 = unary_op(d1 + 0.001)
+            d3 = unary_op(d1 + 0.002)
+            d4 = unary_op(d1 + 0.003)
+            d1 = unary_op(d1)
+        a = binary_op(d1, d2)
+        b = binary_op(d3, d4)
+        c = a + b
+        return c
+
     def forward(self, d1, d2, d3, d4):
-        y = d1 * d2 + d3 * d4
-        return y
+        binary_op = self.__class__.binary_op_pt_func
+        unary_op = self.__class__.unary_op_pt_func
+        return self._eval(d1, d2, d3, d4, binary_op, unary_op)
 
     def reference(self):
-        return self.numpy(self.d1) * self.numpy(self.d2) + self.numpy(self.d3) * self.numpy(self.d4)
+        binary_op = self.__class__.binary_op_np_func
+        unary_op = self.__class__.unary_op_np_func
+        [d1, d2, d3, d4] = [self.numpy(d) for d in [self.d1, self.d2, self.d3, self.d4]]
+        return self._eval(d1, d2, d3, d4, binary_op, unary_op)
 
     def config(self):
         return [self.N]
 
-    @staticmethod
-    def module():
-        return 'element_mul'
+    @classmethod
+    def module(cls):
+        return 'element_' + cls.op_str
 
     def memory_workload(self):
+        input_count = len(self.inputs)
         if self.mode == 'fwd':
-            sol_count = 4 + 1
-            algorithmic_count = 3 + 1
+            if self.split_input:
+                sol_count = input_count + 1
+                algorithmic_count = input_count + 1
+            else:
+                sol_count = 1 + 1
+                algorithmic_count = 1 + 1
         else:
-            sol_count = (4 + 1) + (1 + 4)
-            algorithmic_count = (4 + 1) + ((2 + 1) * 4)
+            if self.split_input:
+                sol_count = (input_count + 1) + (1 + input_count)
+                algorithmic_count = (input_count + 1) + ((2 + 1) * input_count)
+            else:
+                sol_count = 1 + 1
+                algorithmic_count = 1 + 1
 
         buffer_size = self.N * 4
         return {'sol': buffer_size * sol_count, 'algorithmic': buffer_size * algorithmic_count}
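Note on the new counters (a worked example of my own, not part of the commit): `sol` appears to model best-case memory traffic, with each distinct input buffer read once plus one output write, while `algorithmic` counts the reads the naive expression would perform. In shared mode all four operands are derived from `d1`, so ideally only one input buffer is read. For the default config:

# Worked example for the default config: N = 1 << 27 float32 elements,
# forward mode. Pure arithmetic; the variable names are illustrative only.
N = 1 << 27
buffer_size = N * 4                          # 536,870,912 bytes per tensor
input_count = 4                              # d1..d4

sol_split = buffer_size * (input_count + 1)  # 4 reads + 1 write = 2.5 GiB
sol_shared = buffer_size * (1 + 1)           # 1 read + 1 write = 1.0 GiB
print(sol_split, sol_shared)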
@@ -41,4 +85,56 @@ def default_configs():
         return [[1 << 27]]
 
 
-framework.register_benchmark_class(ElementMulBench)
+def register_element_ops():
+    binary_op_list = [
+        ["mul", lambda a, b: a * b],
+        ["add", lambda a, b: a + b],
+        ["sub", lambda a, b: a - b],
+        ["div", lambda a, b: a / (b + 1e-4)],
+        ["pow", lambda a, b: torch.pow(a, b), lambda a, b: np.power(a, b)],  # no fusion triggered
+        ["max", lambda a, b: torch.max(a, b), lambda a, b: np.maximum(a, b)],
+        ["min", lambda a, b: torch.min(a, b), lambda a, b: np.minimum(a, b)],
+    ]
+
+    unary_op_list = [
+        ["exp", lambda x: torch.exp(x), lambda x: np.exp(x)],
+        ["sin", lambda x: torch.sin(x), lambda x: np.sin(x)],
+        ["cos", lambda x: torch.cos(x), lambda x: np.cos(x)],
+    ]
+
+    for split_input, binary_op in itertools.product([True, False], binary_op_list):
+        # Make a copy of ElementBench
+        if len(binary_op) == 2:
+            [op_str, op_pt_func] = binary_op
+            op_np_func = op_pt_func
+        elif len(binary_op) == 3:
+            [op_str, op_pt_func, op_np_func] = binary_op
+        split_str = 'split' if split_input else 'shared'
+        op_str = split_str + '_' + op_str
+        bm_cls = type('ElementBench_' + op_str, (ElementBench,), {})
+        bm_cls.op_str = op_str
+        bm_cls.binary_op_pt_func = op_pt_func
+        bm_cls.binary_op_np_func = op_np_func
+        bm_cls.split_input = split_input
+        framework.register_benchmark_class(bm_cls)
+
+    for split_input, unary_op in itertools.product([True, False], unary_op_list):
+        # Make a copy of ElementBench
+        if len(unary_op) == 2:
+            [op_str, op_pt_func] = unary_op
+            op_np_func = op_pt_func
+        elif len(unary_op) == 3:
+            [op_str, op_pt_func, op_np_func] = unary_op
+        split_str = 'split' if split_input else 'shared'
+        op_str = split_str + '_' + op_str
+        bm_cls = type('ElementBench_' + op_str, (ElementBench,), {})
+        bm_cls.op_str = op_str
+        bm_cls.unary_op_pt_func = op_pt_func
+        bm_cls.unary_op_np_func = op_np_func
+        bm_cls.split_input = split_input
+        framework.register_benchmark_class(bm_cls)
+
+
+# framework.register_benchmark_class(ElementMulBench)
+register_element_ops()
+
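The registration loops above stamp out one benchmark class per (layout, op) pair with `type(name, bases, dict)` instead of writing each subclass by hand. A standalone sketch of the same pattern, where `Benchmark` and `registry` are hypothetical stand-ins for the framework:

# Minimal sketch of dynamic subclass registration; Benchmark and registry
# here are hypothetical stand-ins, not the framework's real API.
registry = {}

class Benchmark:
    op_str = None

    @classmethod
    def module(cls):
        return 'element_' + cls.op_str

def register(op_str):
    # type() builds a fresh subclass at runtime, one per benchmark variant.
    cls = type('Bench_' + op_str, (Benchmark,), {})
    cls.op_str = op_str
    registry[cls.module()] = cls

for name in ['mul', 'add']:
    register(name)

print(sorted(registry))  # ['element_add', 'element_mul']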

benchmarks/tensorexpr/framework.py

Lines changed: 1 addition & 2 deletions
@@ -24,7 +24,7 @@ def forward(self):
 
     def check(self):
         np.testing.assert_allclose(
-            self.reference(), self.numpy(self.forward(*self.inputs)), atol=1e-7)
+            self.reference(), self.numpy(self.forward(*self.inputs)), atol=1e-2)
 
     def config(self):
         '''returns an array for the current benchmark configs
@@ -81,7 +81,6 @@ def __init__(self, mode, device):
             method_engine = getattr(self.engine, method)
             setattr(self, method, method_engine)
 
-
     def rand(self, shape, device=None, requires_grad=False):
         v = self.engine.rand(shape, device=device, requires_grad=requires_grad)
         if requires_grad:
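The looser `atol=1e-2` is presumably headroom for the new transcendental ops, where the backend under test and the numpy reference use different float32 math implementations. A quick standalone way to see such gaps (an illustration of my own, not part of the commit; differences are typically larger on CUDA than in this CPU check):

import numpy as np
import torch

# Compare float32 exp from torch and numpy on identical inputs; the two
# implementations are not required to agree to the old atol=1e-7.
x = np.random.rand(1 << 20).astype(np.float32)
diff = np.abs(torch.exp(torch.from_numpy(x)).numpy() - np.exp(x))
print(diff.max())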

test/test_tensorexpr.py

Lines changed: 2 additions & 2 deletions
@@ -98,7 +98,7 @@ def run_addcmul(x, y, z, w):
 
     x = traced(rand_a, rand_b, rand_c, rand_d)
     y = run_addcmul(rand_a, rand_b, rand_c, rand_d)
-    np.testing.assert_allclose(x.numpy(), y.numpy())
+    np.testing.assert_allclose(x.numpy(), y.numpy(), atol=1e-6)
 
 
 def test_three_arg_cuda():
@@ -678,7 +678,7 @@ def test_relu(x, y):
     traced = torch.jit.trace(torch_fn, (ins, ins))
     x = traced(rand_a, rand_b)
     y = torch_fn(rand_a, rand_b)
-    np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())
+    np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=2e-3)
     # nans
     traced = torch.jit.trace(torch_fn, (ins, ins))
     x = traced(nans, rand_b)

torch/csrc/jit/tensorexpr/cuda_codegen.cpp

Lines changed: 27 additions & 0 deletions
@@ -124,6 +124,33 @@ void CudaPrinter::visit(const For* v) {
   }
 }
 
+void CudaPrinter::visit(const Intrinsics* v) {
+  std::string func_name;
+  // TODO: handle other data types.
+  switch (v->op_type()) {
+    case IntrinsicsOp::kSin:
+      func_name = "sinf";
+      break;
+    case IntrinsicsOp::kCos:
+      func_name = "cosf";
+      break;
+    case IntrinsicsOp::kExp:
+      func_name = "expf";
+      break;
+    default:
+      IRPrinter::visit(v);
+      return;
+  }
+  os() << func_name << "(";
+  for (int i = 0; i < v->nparams(); i++) {
+    if (i > 0) {
+      os() << ", ";
+    }
+    os() << v->param(i);
+  }
+  os() << ")";
+}
+
 void CudaPrinter::visit(const Load* v) {
   // TODO: find a better metric in using ldg or not. Support different dtypes.
   os() << "__ldg(" << v->base_handle() << " + " << v->index() << ")";

torch/csrc/jit/tensorexpr/cuda_codegen.h

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ class CudaPrinter : public IRPrinter {
     os() << ")";
   }
 
+  void visit(const Intrinsics* v);
   void visit(const For* v);
 
   void visit(const Load* v);
