@@ -7,13 +7,16 @@ def __init__(self, mode, device, N):
7
7
self .N = N
8
8
self .d1 = self .rand ([N ], device = device , requires_grad = self .requires_grad )
9
9
self .d2 = self .rand ([N ], device = device , requires_grad = self .requires_grad )
10
+ self .d3 = self .rand ([N ], device = device , requires_grad = self .requires_grad )
11
+ self .d4 = self .rand ([N ], device = device , requires_grad = self .requires_grad )
12
+ self .inputs = [self .d1 , self .d2 , self .d3 , self .d4 ]
10
13
11
def forward(self, d1, d2, d3, d4):
    """Compute the benchmarked expression ``d1 * d2 + d3 * d4``.

    The operands are taken as explicit arguments rather than read from
    ``self``; presumably the harness passes the entries of ``self.inputs``
    in order -- TODO confirm against the caller.

    Args:
        d1, d2: operands of the first product.
        d3, d4: operands of the second product.

    Returns:
        The sum of the two products, using whatever ``*`` and ``+`` mean
        for the operand type.
    """
    return d1 * d2 + d3 * d4
14
17
15
18
def reference(self):
    """Return the expected result ``d1 * d2 + d3 * d4`` for the stored inputs.

    Each stored operand is first passed through ``self.numpy`` (presumably
    a tensor-to-NumPy conversion supplied by the benchmark base class --
    verify there), and the arithmetic is then done on the converted values.
    """
    # Convert in the same left-to-right order as the expression reads.
    first, second, third, fourth = (
        self.numpy(operand)
        for operand in (self.d1, self.d2, self.d3, self.d4)
    )
    return first * second + third * fourth
17
20
18
21
def config(self):
    """Return this benchmark's configuration: a one-element list holding ``N``."""
    size = self.N
    return [size]
@@ -24,11 +27,11 @@ def module():
24
27
25
28
def memory_workload(self):
    """Estimate the memory traffic, in bytes, of one benchmark run.

    Two estimates are returned, each as a buffer count multiplied by the
    size of a single buffer (``self.N * 4``; the factor 4 is presumably
    the byte width of a float32 element -- TODO confirm):

    * ``'sol'`` -- presumably the "speed of light" minimum traffic.
    * ``'algorithmic'`` -- presumably the traffic of the unfused
      op-by-op evaluation.

    Returns:
        dict with keys ``'sol'`` and ``'algorithmic'``.
    """
    input_count = 4   # d1, d2, d3, d4
    output_count = 1  # the single result y
    forward_count = input_count + output_count

    if self.mode == 'fwd':
        sol_count = forward_count
        # Use the same forward term as the non-'fwd' branch below; a
        # smaller value here would make the algorithmic estimate fall
        # below the 'sol' estimate for the same computation.
        algorithmic_count = forward_count
    else:
        # Forward term plus a backward term: one incoming buffer and one
        # outgoing buffer per input for 'sol', and (2 + 1) buffers per
        # input for 'algorithmic'.
        sol_count = forward_count + (1 + input_count)
        algorithmic_count = forward_count + (2 + 1) * input_count

    buffer_size = self.N * 4
    return {
        'sol': buffer_size * sol_count,
        'algorithmic': buffer_size * algorithmic_count,
    }
0 commit comments