@@ -36,8 +36,12 @@ def _basic_model(data, model, args, is_train):
36
36
image = data [0 ]
37
37
label = data [1 ]
38
38
if args .model == "ResNet50" :
39
- image_in = fluid .layers .transpose (image , [0 , 2 , 3 , 1 ]) if args .data_format == 'NHWC' else image
40
- net_out = model .net (input = image_in , class_dim = args .class_dim , data_format = args .data_format )
39
+ image_in = fluid .layers .transpose (
40
+ image , [0 , 2 , 3 , 1 ]) if args .data_format == 'NHWC' else image
41
+ image_in .stop_gradient = image .stop_gradient
42
+ net_out = model .net (input = image_in ,
43
+ class_dim = args .class_dim ,
44
+ data_format = args .data_format )
41
45
else :
42
46
net_out = model .net (input = image , class_dim = args .class_dim )
43
47
softmax_out = fluid .layers .softmax (net_out , use_cudnn = False )
@@ -92,8 +96,12 @@ def _mixup_model(data, model, args, is_train):
92
96
lam = data [3 ]
93
97
94
98
if args .model == "ResNet50" :
95
- image_in = fluid .layers .transpose (image , [0 , 2 , 3 , 1 ]) if args .data_format == 'NHWC' else image
96
- net_out = model .net (input = image_in , class_dim = args .class_dim , data_format = args .data_format )
99
+ image_in = fluid .layers .transpose (
100
+ image , [0 , 2 , 3 , 1 ]) if args .data_format == 'NHWC' else image
101
+ image_in .stop_gradient = image .stop_gradient
102
+ net_out = model .net (input = image_in ,
103
+ class_dim = args .class_dim ,
104
+ data_format = args .data_format )
97
105
else :
98
106
net_out = model .net (input = image , class_dim = args .class_dim )
99
107
softmax_out = fluid .layers .softmax (net_out , use_cudnn = False )
0 commit comments