Skip to content

Commit a234172

Browse files
author
root
committed
Pad input images to 4 channels to use Tensor Cores
1 parent 294ff30 commit a234172

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

PaddleCV/image_classification/dali.py

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,8 @@ def __init__(self,
4343
num_shards=1,
4444
random_shuffle=True,
4545
num_threads=4,
46-
seed=42):
46+
seed=42,
47+
pad_output=False):
4748
super(HybridTrainPipe, self).__init__(
4849
batch_size, num_threads, device_id, seed=seed)
4950
self.input = ops.FileReader(
@@ -73,7 +74,8 @@ def __init__(self,
7374
crop=(crop, crop),
7475
image_type=types.RGB,
7576
mean=mean,
76-
std=std)
77+
std=std,
78+
pad_output=pad_output)
7779
self.coin = ops.CoinFlip(probability=0.5)
7880
self.to_int64 = ops.Cast(dtype=types.INT64, device="gpu")
7981

@@ -104,7 +106,8 @@ def __init__(self,
104106
num_shards=1,
105107
random_shuffle=False,
106108
num_threads=4,
107-
seed=42):
109+
seed=42,
110+
pad_output=False):
108111
super(HybridValPipe, self).__init__(
109112
batch_size, num_threads, device_id, seed=seed)
110113
self.input = ops.FileReader(
@@ -123,7 +126,8 @@ def __init__(self,
123126
crop=(crop, crop),
124127
image_type=types.RGB,
125128
mean=mean,
126-
std=std)
129+
std=std,
130+
pad_output=pad_output)
127131
self.to_int64 = ops.Cast(dtype=types.INT64, device="gpu")
128132

129133
def define_graph(self):
@@ -169,6 +173,9 @@ def build(settings, mode='train'):
169173
}
170174
assert interp in interp_map, "interpolation method not supported by DALI"
171175
interp = interp_map[interp]
176+
pad_output = False
177+
if settings.image_shape[0] == 4:
178+
pad_output = True
172179

173180
if mode != 'train':
174181
p = fluid.framework.cuda_places()[0]
@@ -188,7 +195,8 @@ def build(settings, mode='train'):
188195
interp,
189196
mean,
190197
std,
191-
device_id=device_id)
198+
device_id=device_id,
199+
pad_output=pad_output)
192200
pipe.build()
193201
return DALIGenericIterator(
194202
pipe, ['feed_image', 'feed_label'],
@@ -221,7 +229,8 @@ def build(settings, mode='train'):
221229
device_id,
222230
shard_id,
223231
num_shards,
224-
seed=42 + shard_id)
232+
seed=42 + shard_id,
233+
pad_output=pad_output)
225234
pipe.build()
226235
pipelines = [pipe]
227236
sample_per_shard = len(pipe) // num_shards
@@ -248,7 +257,8 @@ def build(settings, mode='train'):
248257
device_id,
249258
idx,
250259
num_shards,
251-
seed=42 + idx)
260+
seed=42 + idx,
261+
pad_output=pad_output)
252262
pipe.build()
253263
pipelines.append(pipe)
254264
sample_per_shard = len(pipelines[0])

PaddleCV/image_classification/reader.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,9 @@ def process_image(sample, settings, mode, color_jitter, rotate):
245245
img_std = np.array(std).reshape((3, 1, 1))
246246
img -= img_mean
247247
img /= img_std
248+
if settings.image_shape[0] == 4:
249+
pad0 = np.zeros((1, img.shape[1], img.shape[2]))
250+
img = np.concatenate((img, pad0), axis=0)
248251
# doing training (train.py)
249252
if mode == 'train' or (mode == 'val' and
250253
not hasattr(settings, 'save_json_path')):

0 commit comments

Comments
 (0)