Commit 4729d65

Fixing batch size param in transformer training (#1244)
1 parent: 9f89a6f

File tree

1 file changed (+2, -2 lines):

  • models/language_translation/tensorflow/transformer_mlperf/training/bfloat16/transformer

models/language_translation/tensorflow/transformer_mlperf/training/bfloat16/transformer/transformer_main.py

Lines changed: 2 additions & 2 deletions
@@ -324,8 +324,6 @@ def train_schedule(
   #profile file will be saved in in profile_dir
   #Creating hooks for printing Examples per Second, used with estimator.train
   training_batch_size = estimator.params.batch_size
-  if FLAGS.batch_size != -1:
-    training_batch_size = FLAGS.batch_size
   train_hooks = hooks_helper.get_train_hooks(
       ["ExamplesPerSecondHook"],
       model_dir=FLAGS.model_dir,
@@ -429,6 +427,8 @@ def main(_):
   params.repeat_dataset = single_iteration_train_epochs
   params.horovod = is_mpi
   params.static_batch = FLAGS.static_batch
+  if FLAGS.batch_size != -1:
+    params.batch_size = FLAGS.batch_size
   # Add inter_op and intra_op parallelism thread
   session_config = tf.compat.v1.ConfigProto(
       inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
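Taken together, the two hunks move the batch-size override from train_schedule into main: previously FLAGS.batch_size only replaced the local training_batch_size fed to the ExamplesPerSecondHook, while everything else still read the unmodified params.batch_size. Applying the override to params.batch_size once in main, before params is consumed, makes every downstream reader see the same value. Below is a minimal runnable sketch of that pattern; the Params class, train_schedule stub, and flag plumbing are hypothetical stand-ins, not the repo's actual code.

# Sketch of the pattern this commit applies: fold the command-line override
# into the shared params object once, up front, instead of re-checking the
# flag at each use site.

class Params:
    batch_size = 2048  # model default, standing in for the pre-override params

def train_schedule(params):
    # Every consumer now reads the already-overridden value. Before the fix,
    # only this local variable was patched, while the input pipeline kept
    # using the default.
    training_batch_size = params.batch_size
    print("examples/sec hook uses batch size:", training_batch_size)

def main(flag_batch_size):
    params = Params()
    if flag_batch_size != -1:  # -1 is the "flag unset" sentinel used in the diff
        params.batch_size = flag_batch_size
    train_schedule(params)

main(flag_batch_size=4096)  # prints: examples/sec hook uses batch size: 4096

The design point is that a flag override is applied exactly once, at the boundary where the shared configuration is assembled, so callers deeper in the training path never need to know the flag exists.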
