 SEED = 123

+class TimeCostAverage(object):
+    def __init__(self):
+        self.reset()
+    def reset(self):
+        self.cnt = 0
+        self.total_time = 0
+    def record(self, usetime):
+        self.cnt += 1
+        self.total_time += usetime
+    def get_average(self):
+        if self.cnt == 0:
+            return 0
+        return self.total_time / self.cnt

 @contextlib.contextmanager
 def profile_context(profile=True, profiler_path='/tmp/paddingrnn.profile'):
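
For reference, a minimal sketch of how the new TimeCostAverage helper is driven: record each batch cost, report the running mean at a logging point, then reset. It assumes the TimeCostAverage class from the hunk above is in scope; the loop, interval, and train_one_batch below are hypothetical stand-ins, not part of the patch.

import time

def train_one_batch():
    time.sleep(0.001)  # hypothetical stand-in for a real training step

avg = TimeCostAverage()
log_interval = 100  # hypothetical logging interval
for step in range(1, 1001):
    start = time.time()
    train_one_batch()
    avg.record(time.time() - start)
    if step % log_interval == 0:
        # mean cost over the batches recorded since the last log point
        print("avg batch cost: %.5f s" % avg.get_average())
        avg.reset()
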
@@ -293,8 +306,10 @@ def train_an_epoch(epoch_id, batch_times):

     total_loss = 0
     iters = 0
+    batch_cost_avg = TimeCostAverage()

     init_hidden, init_cell = generate_init_data()
+    batch_start_time = time.time()
     for batch_id, batch in enumerate(train_data_iter):
         input_data_feed = prepare_input(
             batch,
@@ -303,7 +318,6 @@ def train_an_epoch(epoch_id, batch_times):
             epoch_id=epoch_id,
             with_lr=True,
             device_count=device_count)
-        batch_start_time = time.time()
         fetch_outs = exe.run(train_program,
                              feed=input_data_feed,
                              fetch_list=[
@@ -313,6 +327,7 @@ def train_an_epoch(epoch_id, batch_times):
                              use_program_cache=True)
         batch_time = time.time() - batch_start_time
         batch_times.append(batch_time)
+        batch_cost_avg.record(batch_time)

         cost_train = np.array(fetch_outs[0])
         lr = np.array(fetch_outs[1])
@@ -324,13 +339,17 @@ def train_an_epoch(epoch_id, batch_times):
             ppl = np.exp(total_loss / iters)
             print(
                 "-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f"
-                % (epoch_id, batch_id, batch_time, ppl[0], lr[0]))
+                % (epoch_id, batch_id, batch_cost_avg.get_average(), ppl[0], lr[0]))
+            batch_cost_avg.reset()

         # profiler tools for benchmark
         if args.profile and batch_id == log_interval:
             profiler.reset_profiler()
         elif args.profile and batch_id == (log_interval + 5):
             break
+
+        batch_start_time = time.time()
+
     ppl = np.exp(total_loss / iters)
     return ppl

@@ -342,6 +361,7 @@ def train_an_epoch_dataloader(epoch_id, batch_times):

     total_loss = 0
     iters = 0
+    batch_cost_avg = TimeCostAverage()

     dataloader.start()
     batch_id = 0
@@ -355,6 +375,7 @@ def train_an_epoch_dataloader(epoch_id, batch_times):
             batch_time = time.time() - batch_start_time
             batch_times.append(batch_time)
             batch_start_time = time.time()
+            batch_cost_avg.record(batch_time)

             new_lr = generate_new_lr(epoch_id, device_count)
             data_feeds['learning_rate'] = new_lr
@@ -381,7 +402,8 @@ def train_an_epoch_dataloader(epoch_id, batch_times):
                 ppl = np.exp(total_loss / iters)
                 print(
                     "-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f"
-                    % (epoch_id, batch_id, batch_time, ppl[0], lr[0]))
+                    % (epoch_id, batch_id, batch_cost_avg.get_average(), ppl[0], lr[0]))
+                batch_cost_avg.reset()

             batch_id += 1
             # profiler tools for benchmark
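
Note on the timing window in train_an_epoch: batch_start_time is now taken once before the loop and refreshed at the end of each iteration, rather than just before exe.run, so each recorded cost covers data preparation (prepare_input) as well as the run itself. Below is a condensed, runnable sketch of that pattern, assuming the TimeCostAverage class above is in scope; timed_epoch, prepare, and run are hypothetical names, not part of the patch.

import time

def timed_epoch(num_batches, prepare, run):
    avg = TimeCostAverage()
    batch_start_time = time.time()  # window opens before data preparation
    for i in range(num_batches):
        data = prepare(i)           # data prep is inside the timed window
        run(data)                   # the training step itself
        avg.record(time.time() - batch_start_time)
        batch_start_time = time.time()  # reopen the window for the next batch
    return avg.get_average()

# e.g. timed_epoch(10, lambda i: i, lambda d: time.sleep(0.001))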