Skip to content

Commit bc07a01

Browse files
authored
Transfer the value of stop_gradient for feeding data. (#4831)
test=develop
1 parent 12080a0 commit bc07a01

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

PaddleCV/image_classification/build_model.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,12 @@ def _basic_model(data, model, args, is_train):
3636
image = data[0]
3737
label = data[1]
3838
if args.model == "ResNet50":
39-
image_in = fluid.layers.transpose(image, [0, 2, 3, 1]) if args.data_format == 'NHWC' else image
40-
net_out = model.net(input=image_in, class_dim=args.class_dim, data_format=args.data_format)
39+
image_in = fluid.layers.transpose(
40+
image, [0, 2, 3, 1]) if args.data_format == 'NHWC' else image
41+
image_in.stop_gradient = image.stop_gradient
42+
net_out = model.net(input=image_in,
43+
class_dim=args.class_dim,
44+
data_format=args.data_format)
4145
else:
4246
net_out = model.net(input=image, class_dim=args.class_dim)
4347
softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)
@@ -92,8 +96,12 @@ def _mixup_model(data, model, args, is_train):
9296
lam = data[3]
9397

9498
if args.model == "ResNet50":
95-
image_in = fluid.layers.transpose(image, [0, 2, 3, 1]) if args.data_format == 'NHWC' else image
96-
net_out = model.net(input=image_in, class_dim=args.class_dim, data_format=args.data_format)
99+
image_in = fluid.layers.transpose(
100+
image, [0, 2, 3, 1]) if args.data_format == 'NHWC' else image
101+
image_in.stop_gradient = image.stop_gradient
102+
net_out = model.net(input=image_in,
103+
class_dim=args.class_dim,
104+
data_format=args.data_format)
97105
else:
98106
net_out = model.net(input=image, class_dim=args.class_dim)
99107
softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)

0 commit comments

Comments
 (0)