
Commit db6ce5e

fix language model time print (#4865)
* fix language_model timecost algorithm
* fix dataloader time calc, test=develop
1 parent 295c16b commit db6ce5e
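
Before this change, the printed Time measured only the exe.run call of the single most recent batch: batch_start_time was armed immediately before exe.run, so input preparation was excluded, and the log line reported one batch's cost rather than an average. The patch widens the timing window to the whole iteration and reports the mean cost over each log interval.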

File tree

1 file changed: +25 −3 lines


PaddleNLP/language_model/train.py

Lines changed: 25 additions & 3 deletions
@@ -49,6 +49,19 @@
 
 SEED = 123
 
+class TimeCostAverage(object):
+    def __init__(self):
+        self.reset()
+    def reset(self):
+        self.cnt = 0
+        self.total_time = 0
+    def record(self, usetime):
+        self.cnt += 1
+        self.total_time += usetime
+    def get_average(self):
+        if self.cnt == 0:
+            return 0
+        return self.total_time / self.cnt
 
 @contextlib.contextmanager
 def profile_context(profile=True, profiler_path='/tmp/paddingrnn.profile'):
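
The new helper is a plain accumulator: record() adds one batch's wall time, get_average() returns the mean so far, and reset() clears the window. Below is a minimal, self-contained sketch of how the training loop is expected to drive it; the driver loop, sleep, and log_interval value are illustrative stand-ins, not part of the patch.

import time

class TimeCostAverage(object):
    def __init__(self):
        self.reset()
    def reset(self):
        self.cnt = 0
        self.total_time = 0
    def record(self, usetime):
        self.cnt += 1
        self.total_time += usetime
    def get_average(self):
        if self.cnt == 0:
            return 0
        return self.total_time / self.cnt

batch_cost_avg = TimeCostAverage()
log_interval = 10                     # illustrative; the script reads its own flag
for batch_id in range(50):
    start = time.time()
    time.sleep(0.001)                 # stand-in for one training iteration
    batch_cost_avg.record(time.time() - start)
    if batch_id > 0 and batch_id % log_interval == 0:
        print("batch %d, avg batch time %.5f s"
              % (batch_id, batch_cost_avg.get_average()))
        batch_cost_avg.reset()        # start a fresh averaging window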
@@ -293,8 +306,10 @@ def train_an_epoch(epoch_id, batch_times):
 
         total_loss = 0
         iters = 0
+        batch_cost_avg = TimeCostAverage()
 
         init_hidden, init_cell = generate_init_data()
+        batch_start_time = time.time()
         for batch_id, batch in enumerate(train_data_iter):
             input_data_feed = prepare_input(
                 batch,
@@ -303,7 +318,6 @@ def train_an_epoch(epoch_id, batch_times):
                 epoch_id=epoch_id,
                 with_lr=True,
                 device_count=device_count)
-            batch_start_time = time.time()
             fetch_outs = exe.run(train_program,
                                  feed=input_data_feed,
                                  fetch_list=[
@@ -313,6 +327,7 @@ def train_an_epoch(epoch_id, batch_times):
                                  use_program_cache=True)
             batch_time = time.time() - batch_start_time
             batch_times.append(batch_time)
+            batch_cost_avg.record(batch_time)
 
             cost_train = np.array(fetch_outs[0])
             lr = np.array(fetch_outs[1])
@@ -324,13 +339,17 @@ def train_an_epoch(epoch_id, batch_times):
                 ppl = np.exp(total_loss / iters)
                 print(
                     "-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f"
-                    % (epoch_id, batch_id, batch_time, ppl[0], lr[0]))
+                    % (epoch_id, batch_id, batch_cost_avg.get_average(), ppl[0], lr[0]))
+                batch_cost_avg.reset()
 
             # profiler tools for benchmark
             if args.profile and batch_id == log_interval:
                 profiler.reset_profiler()
             elif args.profile and batch_id == (log_interval + 5):
                 break
+
+            batch_start_time = time.time()
+
         ppl = np.exp(total_loss / iters)
         return ppl
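
Two fixes land in this function: the timer is armed before the loop and re-armed at the bottom of each iteration, so batch_time now covers prepare_input and feed construction as well as exe.run; and the printed Time is the interval average rather than the cost of the last batch. A minimal sketch of the re-arm pattern follows; step() and the batch source are hypothetical stand-ins.

import time

def step(batch):
    time.sleep(0.001)   # stand-in for prepare_input + exe.run

batch_start_time = time.time()                    # arm once before the loop
for batch in range(10):
    step(batch)
    batch_time = time.time() - batch_start_time   # full iteration cost
    # record/log batch_time here
    batch_start_time = time.time()                # re-arm for the next iteration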

@@ -342,6 +361,7 @@ def train_an_epoch_dataloader(epoch_id, batch_times):
 
         total_loss = 0
         iters = 0
+        batch_cost_avg = TimeCostAverage()
 
         dataloader.start()
         batch_id = 0
@@ -355,6 +375,7 @@ def train_an_epoch_dataloader(epoch_id, batch_times):
                 batch_time = time.time() - batch_start_time
                 batch_times.append(batch_time)
                 batch_start_time = time.time()
+                batch_cost_avg.record(batch_time)
 
                 new_lr = generate_new_lr(epoch_id, device_count)
                 data_feeds['learning_rate'] = new_lr
@@ -381,7 +402,8 @@ def train_an_epoch_dataloader(epoch_id, batch_times):
                    ppl = np.exp(total_loss / iters)
                    print(
                        "-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f"
-                        % (epoch_id, batch_id, batch_time, ppl[0], lr[0]))
+                        % (epoch_id, batch_id, batch_cost_avg.get_average(), ppl[0], lr[0]))
+                    batch_cost_avg.reset()
 
                 batch_id += 1
                 # profiler tools for benchmark
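
In the dataloader variant the timer was already re-armed right after each measurement (see the context lines above), so only the reporting changes here: the log line switches from the last batch_time to the interval mean, and the accumulator is reset after every print so each logged Time describes only the batches since the previous log line.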
