@@ -31,7 +31,7 @@ class ResNet():
31
31
def __init__ (self , layers = 50 ):
32
32
self .layers = layers
33
33
34
- def net (self , input , class_dim = 1000 , data_format = "NCHW" ):
34
+ def net (self , input , class_dim = 1000 , data_format = "NCHW" , fuse_bn_add_act = False ):
35
35
layers = self .layers
36
36
supported_layers = [18 , 34 , 50 , 101 , 152 ]
37
37
assert layers in supported_layers , \
@@ -77,7 +77,8 @@ def net(self, input, class_dim=1000, data_format="NCHW"):
77
77
num_filters = num_filters [block ],
78
78
stride = 2 if i == 0 and block != 0 else 1 ,
79
79
name = conv_name ,
80
- data_format = data_format )
80
+ data_format = data_format ,
81
+ fuse_bn_add_act = fuse_bn_add_act )
81
82
82
83
pool = fluid .layers .pool2d (
83
84
input = conv , pool_type = 'avg' , global_pooling = True , data_format = data_format )
@@ -97,7 +98,8 @@ def net(self, input, class_dim=1000, data_format="NCHW"):
97
98
stride = 2 if i == 0 and block != 0 else 1 ,
98
99
is_first = block == i == 0 ,
99
100
name = conv_name ,
100
- data_format = data_format )
101
+ data_format = data_format ,
102
+ fuse_bn_add_act = fuse_bn_add_act )
101
103
102
104
pool = fluid .layers .pool2d (
103
105
input = conv , pool_type = 'avg' , global_pooling = True , data_format = data_format )
@@ -155,7 +157,7 @@ def shortcut(self, input, ch_out, stride, is_first, name, data_format):
155
157
else :
156
158
return input
157
159
158
- def bottleneck_block (self , input , num_filters , stride , name , data_format ):
160
+ def bottleneck_block (self , input , num_filters , stride , name , data_format , fuse_bn_add_act ):
159
161
conv0 = self .conv_bn_layer (
160
162
input = input ,
161
163
num_filters = num_filters ,
@@ -171,26 +173,56 @@ def bottleneck_block(self, input, num_filters, stride, name, data_format):
171
173
act = 'relu' ,
172
174
name = name + "_branch2b" ,
173
175
data_format = data_format )
174
- conv2 = self .conv_bn_layer (
175
- input = conv1 ,
176
- num_filters = num_filters * 4 ,
177
- filter_size = 1 ,
178
- act = None ,
179
- name = name + "_branch2c" ,
180
- data_format = data_format )
176
+ if not fuse_bn_add_act :
177
+ conv2 = self .conv_bn_layer (
178
+ input = conv1 ,
179
+ num_filters = num_filters * 4 ,
180
+ filter_size = 1 ,
181
+ act = None ,
182
+ name = name + "_branch2c" ,
183
+ data_format = data_format )
184
+ short = self .shortcut (
185
+ input ,
186
+ num_filters * 4 ,
187
+ stride ,
188
+ is_first = False ,
189
+ name = name + "_branch1" ,
190
+ data_format = data_format )
181
191
182
- short = self .shortcut (
183
- input ,
184
- num_filters * 4 ,
185
- stride ,
186
- is_first = False ,
187
- name = name + "_branch1" ,
188
- data_format = data_format )
192
+ return fluid .layers .elementwise_add (
193
+ x = short , y = conv2 , act = 'relu' , name = name + ".add.output.5" )
194
+ else :
195
+ conv2 = fluid .layers .conv2d (
196
+ input = conv1 ,
197
+ num_filters = num_filters * 4 ,
198
+ filter_size = 1 ,
199
+ act = None ,
200
+ param_attr = ParamAttr (name = name + "_branch2c" + "_weights" ),
201
+ bias_attr = False ,
202
+ name = name + '_branch2c' + '.conv2d.output.1' ,
203
+ data_format = data_format )
204
+ short = self .shortcut (
205
+ input ,
206
+ num_filters * 4 ,
207
+ stride ,
208
+ is_first = False ,
209
+ name = name + "_branch1" ,
210
+ data_format = data_format )
211
+ name = name + "_branch2c"
212
+ bn_name = "bn" + name [3 :]
213
+ short = fluid .contrib .layers .fused_bn_add_act (
214
+ conv2 ,
215
+ short ,
216
+ param_attr = ParamAttr (name = bn_name + '_scale' ),
217
+ bias_attr = ParamAttr (bn_name + '_offset' ),
218
+ moving_mean_name = bn_name + '_mean' ,
219
+ moving_variance_name = bn_name + '_variance' ,
220
+ name = name + ".add.output.5" )
189
221
190
- return fluid .layers .elementwise_add (
191
- x = short , y = conv2 , act = 'relu' , name = name + ".add.output.5" )
222
+ return short
192
223
193
- def basic_block (self , input , num_filters , stride , is_first , name , data_format ):
224
+ def basic_block (self , input , num_filters , stride , is_first , name ,
225
+ data_format , fuse_bn_add_act ):
194
226
conv0 = self .conv_bn_layer (
195
227
input = input ,
196
228
num_filters = num_filters ,
@@ -199,16 +231,54 @@ def basic_block(self, input, num_filters, stride, is_first, name, data_format):
199
231
stride = stride ,
200
232
name = name + "_branch2a" ,
201
233
data_format = data_format )
202
- conv1 = self .conv_bn_layer (
203
- input = conv0 ,
204
- num_filters = num_filters ,
205
- filter_size = 3 ,
206
- act = None ,
207
- name = name + "_branch2b" ,
208
- data_format = data_format )
209
- short = self .shortcut (
210
- input , num_filters , stride , is_first , name = name + "_branch1" , data_format = data_format )
211
- return fluid .layers .elementwise_add (x = short , y = conv1 , act = 'relu' )
234
+ if not fuse_bn_add_act :
235
+ conv1 = self .conv_bn_layer (
236
+ input = conv0 ,
237
+ num_filters = num_filters ,
238
+ filter_size = 3 ,
239
+ act = None ,
240
+ name = name + "_branch2b" ,
241
+ data_format = data_format )
242
+ short = self .shortcut (
243
+ input ,
244
+ num_filters ,
245
+ stride ,
246
+ is_first ,
247
+ name = name + "_branch1" ,
248
+ data_format = data_format )
249
+
250
+ return fluid .layers .elementwise_add (x = short , y = conv1 , act = 'relu' )
251
+ else :
252
+ conv1 = fluid .layers .conv2d (
253
+ input = conv0 ,
254
+ num_filters = num_filters ,
255
+ filter_size = 3 ,
256
+ stride = 1 ,
257
+ padding = 1 ,
258
+ groups = 1 ,
259
+ act = None ,
260
+ param_attr = ParamAttr (name = name + "_branch2b" + "_weights" ),
261
+ bias_attr = False ,
262
+ name = name + '_branch2b' + '.conv2d.output.1' ,
263
+ data_format = data_format )
264
+ short = self .shortcut (
265
+ input ,
266
+ num_filters ,
267
+ stride ,
268
+ is_first ,
269
+ name = name + "_branch1" ,
270
+ data_format = data_format )
271
+ name = name + "_branch2b"
272
+ bn_name = "bn" + name [3 :]
273
+ short = fluid .contrib .layers .fused_bn_add_act (
274
+ conv1 ,
275
+ short ,
276
+ param_attr = ParamAttr (name = bn_name + '_scale' ),
277
+ bias_attr = ParamAttr (bn_name + '_offset' ),
278
+ moving_mean_name = bn_name + '_mean' ,
279
+ moving_variance_name = bn_name + '_variance' )
280
+
281
+ return short
212
282
213
283
214
284
def ResNet18 ():
0 commit comments