1
1
import framework
2
+ import itertools
3
+ import numpy as np
4
+ import torch
2
5
3
-
4
- class ElementMulBench (framework .Benchmark ):
6
+ # A template class for elementwise operations.
7
+ # A derived class will override the class instance to customize its behavior.
8
+ class ElementBench (framework .Benchmark ):
9
+ # List of customization class variables.
10
+ op_str = None
11
+ binary_op_pt_func = None
12
+ binary_op_np_func = None
13
+ unary_op_pt_func = None
14
+ unary_op_np_func = None
15
+ split_input = True
5
16
def __init__ (self , mode , device , N ):
6
17
super ().__init__ (mode , device )
7
18
self .N = N
@@ -11,27 +22,60 @@ def __init__(self, mode, device, N):
11
22
self .d4 = self .rand ([N ], device = device , requires_grad = self .requires_grad )
12
23
self .inputs = [self .d1 , self .d2 , self .d3 , self .d4 ]
13
24
25
+ def _eval (self , d1 , d2 , d3 , d4 , binary_op , unary_op ):
26
+ if not binary_op :
27
+ binary_op = lambda x , y : x + y
28
+ if not unary_op :
29
+ unary_op = lambda x : x
30
+ if self .split_input :
31
+ d1 = unary_op (d1 )
32
+ d2 = unary_op (d2 )
33
+ d3 = unary_op (d3 )
34
+ d4 = unary_op (d4 )
35
+ else :
36
+ d2 = unary_op (d1 + 0.001 )
37
+ d3 = unary_op (d1 + 0.002 )
38
+ d4 = unary_op (d1 + 0.003 )
39
+ d1 = unary_op (d1 )
40
+ a = binary_op (d1 , d2 )
41
+ b = binary_op (d3 , d4 )
42
+ c = a + b
43
+ return c
44
+
14
45
def forward (self , d1 , d2 , d3 , d4 ):
15
- y = d1 * d2 + d3 * d4
16
- return y
46
+ binary_op = self .__class__ .binary_op_pt_func
47
+ unary_op = self .__class__ .unary_op_pt_func
48
+ return self ._eval (d1 , d2 , d3 , d4 , binary_op , unary_op )
17
49
18
50
def reference (self ):
19
- return self .numpy (self .d1 ) * self .numpy (self .d2 ) + self .numpy (self .d3 ) * self .numpy (self .d4 )
51
+ binary_op = self .__class__ .binary_op_np_func
52
+ unary_op = self .__class__ .unary_op_np_func
53
+ [d1 , d2 , d3 , d4 ] = [self .numpy (d ) for d in [self .d1 , self .d2 , self .d3 , self .d4 ]]
54
+ return self ._eval (d1 , d2 , d3 , d4 , binary_op , unary_op )
20
55
21
56
def config (self ):
22
57
return [self .N ]
23
58
24
- @staticmethod
25
- def module ():
26
- return 'element_mul'
59
+ @classmethod
60
+ def module (cls ):
61
+ return 'element_' + cls . op_str
27
62
28
63
def memory_workload (self ):
64
+ input_count = len (self .inputs )
29
65
if self .mode == 'fwd' :
30
- sol_count = 4 + 1
31
- algorithmic_count = 3 + 1
66
+ if self .split_input :
67
+ sol_count = input_count + 1
68
+ algorithmic_count = input_count + 1
69
+ else :
70
+ sol_count = 1 + 1
71
+ algorithmic_count = 1 + 1
32
72
else :
33
- sol_count = (4 + 1 ) + (1 + 4 )
34
- algorithmic_count = (4 + 1 ) + ((2 + 1 ) * 4 )
73
+ if self .split_input :
74
+ sol_count = (input_count + 1 ) + (1 + input_count )
75
+ algorithmic_count = (input_count + 1 ) + ((2 + 1 ) * input_count )
76
+ else :
77
+ sol_count = 1 + 1
78
+ algorithmic_count = 1 + 1
35
79
36
80
buffer_size = self .N * 4
37
81
return {'sol' : buffer_size * sol_count , 'algorithmic' : buffer_size * algorithmic_count }
@@ -41,4 +85,56 @@ def default_configs():
41
85
return [[1 << 27 ]]
42
86
43
87
44
- framework .register_benchmark_class (ElementMulBench )
88
+ def register_element_ops ():
89
+ binary_op_list = [
90
+ ["mul" , lambda a , b : a * b ],
91
+ ["add" , lambda a , b : a + b ],
92
+ ["sub" , lambda a , b : a - b ],
93
+ ["div" , lambda a , b : a / (b + 1e-4 )],
94
+ ["pow" , lambda a , b : torch .pow (a , b ), lambda a , b : np .power (a , b )], # no fuson triggered
95
+ ["max" , lambda a , b : torch .max (a , b ), lambda a , b : np .maximum (a , b )],
96
+ ["min" , lambda a , b : torch .min (a , b ), lambda a , b : np .minimum (a , b )],
97
+ ]
98
+
99
+ unary_op_list = [
100
+ ["exp" , lambda x : torch .exp (x ), lambda x : np .exp (x )],
101
+ ["sin" , lambda x : torch .sin (x ), lambda x : np .sin (x )],
102
+ ["cos" , lambda x : torch .cos (x ), lambda x : np .cos (x )],
103
+ ]
104
+
105
+ for split_input , binary_op in itertools .product ([True , False ], binary_op_list ):
106
+ # Make a copy of ElementBench
107
+ if len (binary_op ) == 2 :
108
+ [op_str , op_pt_func ] = binary_op
109
+ op_np_func = op_pt_func
110
+ elif len (binary_op ) == 3 :
111
+ [op_str , op_pt_func , op_np_func ] = binary_op
112
+ split_str = 'split' if split_input else 'shared'
113
+ op_str = split_str + '_' + op_str
114
+ bm_cls = type ('ElementBench_' + op_str , (ElementBench ,), {})
115
+ bm_cls .op_str = op_str
116
+ bm_cls .binary_op_pt_func = op_pt_func
117
+ bm_cls .binary_op_np_func = op_np_func
118
+ bm_cls .split_input = split_input
119
+ framework .register_benchmark_class (bm_cls )
120
+
121
+ for split_input , unary_op in itertools .product ([True , False ], unary_op_list ):
122
+ # Make a copy of ElementBench
123
+ if len (unary_op ) == 2 :
124
+ [op_str , op_pt_func ] = unary_op
125
+ op_np_func = op_pt_func
126
+ elif len (unary_op ) == 3 :
127
+ [op_str , op_pt_func , op_np_func ] = unary_op
128
+ split_str = 'split' if split_input else 'shared'
129
+ op_str = split_str + '_' + op_str
130
+ bm_cls = type ('ElementBench_' + op_str , (ElementBench ,), {})
131
+ bm_cls .op_str = op_str
132
+ bm_cls .unary_op_pt_func = op_pt_func
133
+ bm_cls .unary_op_np_func = op_np_func
134
+ bm_cls .split_input = split_input
135
+ framework .register_benchmark_class (bm_cls )
136
+
137
+
138
+ #framework.register_benchmark_class(ElementMulBench)
139
+ register_element_ops ()
140
+
0 commit comments