@@ -43,7 +43,8 @@ def __init__(self,
43
43
num_shards = 1 ,
44
44
random_shuffle = True ,
45
45
num_threads = 4 ,
46
- seed = 42 ):
46
+ seed = 42 ,
47
+ pad_output = False ):
47
48
super (HybridTrainPipe , self ).__init__ (
48
49
batch_size , num_threads , device_id , seed = seed )
49
50
self .input = ops .FileReader (
@@ -73,7 +74,8 @@ def __init__(self,
73
74
crop = (crop , crop ),
74
75
image_type = types .RGB ,
75
76
mean = mean ,
76
- std = std )
77
+ std = std ,
78
+ pad_output = pad_output )
77
79
self .coin = ops .CoinFlip (probability = 0.5 )
78
80
self .to_int64 = ops .Cast (dtype = types .INT64 , device = "gpu" )
79
81
@@ -104,7 +106,8 @@ def __init__(self,
104
106
num_shards = 1 ,
105
107
random_shuffle = False ,
106
108
num_threads = 4 ,
107
- seed = 42 ):
109
+ seed = 42 ,
110
+ pad_output = False ):
108
111
super (HybridValPipe , self ).__init__ (
109
112
batch_size , num_threads , device_id , seed = seed )
110
113
self .input = ops .FileReader (
@@ -123,7 +126,8 @@ def __init__(self,
123
126
crop = (crop , crop ),
124
127
image_type = types .RGB ,
125
128
mean = mean ,
126
- std = std )
129
+ std = std ,
130
+ pad_output = pad_output )
127
131
self .to_int64 = ops .Cast (dtype = types .INT64 , device = "gpu" )
128
132
129
133
def define_graph (self ):
@@ -169,6 +173,9 @@ def build(settings, mode='train'):
169
173
}
170
174
assert interp in interp_map , "interpolation method not supported by DALI"
171
175
interp = interp_map [interp ]
176
+ pad_output = False
177
+ if settings .image_shape [0 ] == 4 :
178
+ pad_output = True
172
179
173
180
if mode != 'train' :
174
181
p = fluid .framework .cuda_places ()[0 ]
@@ -188,7 +195,8 @@ def build(settings, mode='train'):
188
195
interp ,
189
196
mean ,
190
197
std ,
191
- device_id = device_id )
198
+ device_id = device_id ,
199
+ pad_output = pad_output )
192
200
pipe .build ()
193
201
return DALIGenericIterator (
194
202
pipe , ['feed_image' , 'feed_label' ],
@@ -221,7 +229,8 @@ def build(settings, mode='train'):
221
229
device_id ,
222
230
shard_id ,
223
231
num_shards ,
224
- seed = 42 + shard_id )
232
+ seed = 42 + shard_id ,
233
+ pad_output = pad_output )
225
234
pipe .build ()
226
235
pipelines = [pipe ]
227
236
sample_per_shard = len (pipe ) // num_shards
@@ -248,7 +257,8 @@ def build(settings, mode='train'):
248
257
device_id ,
249
258
idx ,
250
259
num_shards ,
251
- seed = 42 + idx )
260
+ seed = 42 + idx ,
261
+ pad_output = pad_output )
252
262
pipe .build ()
253
263
pipelines .append (pipe )
254
264
sample_per_shard = len (pipelines [0 ])
0 commit comments