
Commit 60fbf64

MultiPath authored and facebook-github-bot committed
Add --eval-bleu for translation
Summary: Pull Request resolved: fairinternal/fairseq-py#989
Reviewed By: MultiPath
Differential Revision: D19411162
Pulled By: myleott
fbshipit-source-id: 74842f0174f58e39a13fb90f3cc1170c63bc89be
1 parent 122fc1d commit 60fbf64

File tree

5 files changed, +171 −26 lines changed


examples/translation/README.md

Lines changed: 7 additions & 1 deletion
@@ -116,7 +116,13 @@ CUDA_VISIBLE_DEVICES=0 fairseq-train \
     --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
     --dropout 0.3 --weight-decay 0.0001 \
     --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
-    --max-tokens 4096
+    --max-tokens 4096 \
+    --eval-bleu \
+    --eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \
+    --eval-bleu-detok moses \
+    --eval-bleu-remove-bpe \
+    --eval-bleu-print-samples \
+    --best-checkpoint-metric bleu --maximize-best-checkpoint-metric
 ```

 Finally we can evaluate our trained model:
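The `--eval-bleu-args` JSON above controls how validation-time generation is run: `beam` is the beam size, and `max_len_a`/`max_len_b` bound the hypothesis length as a linear function of the source length (fairseq's usual `max_len_a * src_len + max_len_b` convention). A minimal sketch of that arithmetic; `max_decode_len` is a hypothetical helper written here for illustration, not part of fairseq:

```python
# Sketch: how the generation settings passed via --eval-bleu-args are
# typically interpreted. "max_len_a"/"max_len_b" bound the hypothesis
# length as a linear function of the source length; "beam" is beam size.
import json

eval_bleu_args = '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}'
gen_cfg = json.loads(eval_bleu_args)

def max_decode_len(src_len, cfg):
    """Length cap for a hypothesis given a source of src_len tokens (illustrative defaults)."""
    return int(cfg.get("max_len_a", 0) * src_len + cfg.get("max_len_b", 200))

print(max_decode_len(20, gen_cfg))  # -> 34 for the settings above
```

With the README's settings, a 20-token source sentence is capped at 1.2 * 20 + 10 = 34 generated tokens.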

fairseq/meters.py

Lines changed: 5 additions & 1 deletion
@@ -230,7 +230,11 @@ def get_smoothed_value(self, key: str) -> float:
 
     def get_smoothed_values(self) -> Dict[str, float]:
         """Get all smoothed values."""
-        return OrderedDict([(key, self.get_smoothed_value(key)) for key in self.keys()])
+        return OrderedDict([
+            (key, self.get_smoothed_value(key))
+            for key in self.keys()
+            if not key.startswith("_")
+        ])
 
     def reset(self):
         """Reset Meter instances."""

fairseq/tasks/translation.py

Lines changed: 125 additions & 2 deletions
@@ -3,23 +3,30 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from argparse import Namespace
+import json
 import itertools
 import logging
 import os
 
-from fairseq import options, utils
+import numpy as np
+
+from fairseq import metrics, options, utils
 from fairseq.data import (
     AppendTokenDataset,
     ConcatDataset,
     data_utils,
+    encoders,
     indexed_dataset,
     LanguagePairDataset,
     PrependTokenDataset,
     StripTokenDataset,
     TruncateDataset,
 )
 
-from . import FairseqTask, register_task
+from fairseq.tasks import FairseqTask, register_task
+
+EVAL_BLEU_ORDER = 4
 
 
 logger = logging.getLogger(__name__)
@@ -155,6 +162,26 @@ def add_args(parser):
                                  help='amount to upsample primary dataset')
         parser.add_argument('--truncate-source', action='store_true', default=False,
                             help='truncate source to max-source-positions')
+
+        # options for reporting BLEU during validation
+        parser.add_argument('--eval-bleu', action='store_true',
+                            help='evaluation with BLEU scores')
+        parser.add_argument('--eval-bleu-detok', type=str, default="space",
+                            help='detokenize before computing BLEU (e.g., "moses"); '
+                                 'required if using --eval-bleu; use "space" to '
+                                 'disable detokenization; see fairseq.data.encoders '
+                                 'for other options')
+        parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON',
+                            help='args for building the tokenizer, if needed')
+        parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False,
+                            help='if set, compute tokenized BLEU instead of sacrebleu')
+        parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None,
+                            help='remove BPE before computing BLEU')
+        parser.add_argument('--eval-bleu-args', type=str, metavar='JSON',
+                            help='generation args for BLEU scoring, '
+                                 'e.g., \'{"beam": 4, "lenpen": 0.6}\'')
+        parser.add_argument('--eval-bleu-print-samples', action='store_true',
+                            help='print sample generations during validation')
         # fmt: on
 
     def __init__(self, args, src_dict, tgt_dict):
@@ -219,6 +246,75 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs):
     def build_dataset_for_inference(self, src_tokens, src_lengths):
         return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary)
 
+    def build_model(self, args):
+        if getattr(args, 'eval_bleu', False):
+            assert getattr(args, 'eval_bleu_detok', None) is not None, (
+                '--eval-bleu-detok is required if using --eval-bleu; '
+                'try --eval-bleu-detok=moses (or --eval-bleu-detok=space '
+                'to disable detokenization, e.g., when using sentencepiece)'
+            )
+            detok_args = json.loads(getattr(args, 'eval_bleu_detok_args', '{}') or '{}')
+            self.tokenizer = encoders.build_tokenizer(Namespace(
+                tokenizer=getattr(args, 'eval_bleu_detok', None),
+                **detok_args
+            ))
+
+            gen_args = json.loads(getattr(args, 'eval_bleu_args', '{}') or '{}')
+            self.sequence_generator = self.build_generator(Namespace(**gen_args))
+        return super().build_model(args)
+
+    def valid_step(self, sample, model, criterion):
+        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
+        if self.args.eval_bleu:
+            bleu = self._inference_with_bleu(self.sequence_generator, sample, model)
+            logging_output['_bleu_sys_len'] = bleu.sys_len
+            logging_output['_bleu_ref_len'] = bleu.ref_len
+            # we split counts into separate entries so that they can be
+            # summed efficiently across workers using fast-stat-sync
+            assert len(bleu.counts) == EVAL_BLEU_ORDER
+            for i in range(EVAL_BLEU_ORDER):
+                logging_output['_bleu_counts_' + str(i)] = bleu.counts[i]
+                logging_output['_bleu_totals_' + str(i)] = bleu.totals[i]
+        return loss, sample_size, logging_output
+
+    def reduce_metrics(self, logging_outputs, criterion):
+        super().reduce_metrics(logging_outputs, criterion)
+        if self.args.eval_bleu:
+
+            def sum_logs(key):
+                return sum(log.get(key, 0) for log in logging_outputs)
+
+            counts, totals = [], []
+            for i in range(EVAL_BLEU_ORDER):
+                counts.append(sum_logs('_bleu_counts_' + str(i)))
+                totals.append(sum_logs('_bleu_totals_' + str(i)))
+
+            if max(totals) > 0:
+                # log counts as numpy arrays -- log_scalar will sum them correctly
+                metrics.log_scalar('_bleu_counts', np.array(counts))
+                metrics.log_scalar('_bleu_totals', np.array(totals))
+                metrics.log_scalar('_bleu_sys_len', sum_logs('_bleu_sys_len'))
+                metrics.log_scalar('_bleu_ref_len', sum_logs('_bleu_ref_len'))
+
+                def compute_bleu(meters):
+                    import inspect
+                    import sacrebleu
+                    fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0]
+                    if 'smooth_method' in fn_sig:
+                        smooth = {'smooth_method': 'exp'}
+                    else:
+                        smooth = {'smooth': 'exp'}
+                    bleu = sacrebleu.compute_bleu(
+                        correct=meters['_bleu_counts'].sum,
+                        total=meters['_bleu_totals'].sum,
+                        sys_len=meters['_bleu_sys_len'].sum,
+                        ref_len=meters['_bleu_ref_len'].sum,
+                        **smooth
+                    )
+                    return round(bleu.score, 2)
+
+                metrics.log_derived('bleu', compute_bleu)
+
     def max_positions(self):
         """Return the max sentence length allowed by the task."""
         return (self.args.max_source_positions, self.args.max_target_positions)
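The per-order match counts and totals logged in `valid_step` and summed in `reduce_metrics` are exactly the sufficient statistics for corpus BLEU, which is why they can be synced across workers as plain scalars and turned back into a score afterwards. Below is a library-free sketch of that reconstruction using the textbook BLEU formula (geometric mean of n-gram precisions times a brevity penalty, no smoothing, unlike the exp-smoothed sacrebleu call above); the function name and the example numbers are made up for illustration:

```python
import math

def corpus_bleu_from_stats(counts, totals, sys_len, ref_len):
    """Recompute corpus BLEU from aggregated sufficient statistics.

    counts[i] / totals[i] are matched and total (i+1)-gram counts summed
    over the whole validation set (and over workers). No smoothing here.
    """
    if min(counts) == 0:
        return 0.0
    # average log precision across n-gram orders
    log_prec = sum(math.log(c / t) for c, t in zip(counts, totals)) / len(counts)
    # brevity penalty
    bp = 1.0 if sys_len >= ref_len else math.exp(1.0 - ref_len / sys_len)
    return 100.0 * bp * math.exp(log_prec)

# Example with made-up aggregated statistics:
print(round(corpus_bleu_from_stats(
    [900, 520, 310, 190], [1000, 980, 960, 940], sys_len=1000, ref_len=1020), 2))
```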
@@ -232,3 +328,30 @@ def source_dictionary(self):
     def target_dictionary(self):
         """Return the target :class:`~fairseq.data.Dictionary`."""
         return self.tgt_dict
+
+    def _inference_with_bleu(self, generator, sample, model):
+        import sacrebleu
+
+        def decode(toks, escape_unk=False):
+            s = self.tgt_dict.string(
+                toks.int().cpu(),
+                self.args.eval_bleu_remove_bpe,
+                escape_unk=escape_unk,
+            )
+            if self.tokenizer:
+                s = self.tokenizer.decode(s)
+            return s
+
+        gen_out = self.inference_step(generator, [model], sample, None)
+        hyps, refs = [], []
+        for i in range(len(gen_out)):
+            hyps.append(decode(gen_out[i][0]['tokens']))
+            refs.append(decode(
+                utils.strip_pad(sample['target'][i], self.tgt_dict.pad()),
+                escape_unk=True,  # don't count <unk> as matches to the hypo
+            ))
+        if self.args.eval_bleu_print_samples:
+            logger.info('example hypothesis: ' + hyps[0])
+            logger.info('example reference: ' + refs[0])
+        tokenize = sacrebleu.DEFAULT_TOKENIZER if not self.args.eval_tokenized_bleu else 'none'
+        return sacrebleu.corpus_bleu(hyps, [refs], tokenize=tokenize)
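`_inference_with_bleu` converts hypotheses and references back to plain, detokenized strings before scoring, so sacrebleu applies its own default tokenization (or none at all when `--eval-tokenized-bleu` is set). A standalone usage sketch of that final scoring call, with toy sentences in place of real model output:

```python
# Standalone sketch of the final scoring step: detokenized hypothesis and
# reference strings go to sacrebleu.corpus_bleu. Strings are toy examples.
import sacrebleu

hyps = ["the cat sat on the mat", "hello world"]
refs = ["the cat is on the mat", "hello there world"]

# One reference stream; sacrebleu expects a list of reference lists.
bleu = sacrebleu.corpus_bleu(hyps, [refs])
print(round(bleu.score, 2))
```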

fairseq_cli/train.py

Lines changed: 21 additions & 22 deletions
@@ -163,39 +163,38 @@ def train(args, trainer, task, epoch_itr):
 
     valid_subsets = args.valid_subset.split(',')
     max_update = args.max_update or math.inf
-    for samples in progress:
-        with metrics.aggregate('train_inner'):
+    with metrics.aggregate() as agg:
+        for samples in progress:
             log_output = trainer.train_step(samples)
             num_updates = trainer.get_num_updates()
             if log_output is None:
                 continue
 
             # log mid-epoch stats
-            stats = get_training_stats('train_inner')
+            stats = get_training_stats(agg.get_smoothed_values())
             progress.log(stats, tag='train', step=num_updates)
 
-        if (
-            not args.disable_validation
-            and args.save_interval_updates > 0
-            and num_updates % args.save_interval_updates == 0
-            and num_updates > 0
-        ):
-            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
-            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
+            if (
+                not args.disable_validation
+                and args.save_interval_updates > 0
+                and num_updates % args.save_interval_updates == 0
+                and num_updates > 0
+            ):
+                valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
+                checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
 
-        if num_updates >= max_update:
-            break
+            if num_updates >= max_update:
+                break
 
     # log end-of-epoch stats
-    stats = get_training_stats('train')
+    stats = get_training_stats(agg.get_smoothed_values())
     progress.print(stats, tag='train', step=num_updates)
 
     # reset epoch-level meters
     metrics.reset_meters('train')
 
 
-def get_training_stats(stats_key):
-    stats = metrics.get_smoothed_values(stats_key)
+def get_training_stats(stats):
     if 'nll_loss' in stats and 'ppl' not in stats:
         stats['ppl'] = utils.get_perplexity(stats['nll_loss'])
     stats['wall'] = round(metrics.get_meter('default', 'wall').elapsed_time, 0)
@@ -233,22 +232,22 @@ def validate(args, trainer, task, epoch_itr, subsets):
         no_progress_bar='simple'
     )
 
-    # reset validation loss meters
+    # reset validation meters
     metrics.reset_meters('valid')
 
-    for sample in progress:
-        trainer.valid_step(sample)
+    with metrics.aggregate() as agg:
+        for sample in progress:
+            trainer.valid_step(sample)
 
     # log validation stats
-    stats = get_valid_stats(args, trainer)
+    stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
     progress.print(stats, tag=subset, step=trainer.get_num_updates())
 
     valid_losses.append(stats[args.best_checkpoint_metric])
     return valid_losses
 
 
-def get_valid_stats(args, trainer):
-    stats = metrics.get_smoothed_values('valid')
+def get_valid_stats(args, trainer, stats):
     if 'nll_loss' in stats and 'ppl' not in stats:
         stats['ppl'] = utils.get_perplexity(stats['nll_loss'])
     stats['num_updates'] = trainer.get_num_updates()
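Both loops now gather statistics through `metrics.aggregate()` used as a context manager and read them back with `get_smoothed_values()`, instead of pulling from a named meter group afterwards. The sketch below is a deliberately simplified stand-in for that pattern, not fairseq's `metrics` module: a context manager yields an accumulator, scalars logged inside the block are averaged, and underscore-prefixed keys stay hidden, mirroring the `meters.py` change above.

```python
# Toy illustration of the aggregate() pattern: everything logged inside the
# "with" block is accumulated locally and read back as smoothed values when
# the block is done. Simplified stand-in, not fairseq.metrics.
from contextlib import contextmanager

class _Aggregator:
    def __init__(self):
        self._sums, self._counts = {}, {}

    def log_scalar(self, key, value):
        self._sums[key] = self._sums.get(key, 0.0) + value
        self._counts[key] = self._counts.get(key, 0) + 1

    def get_smoothed_values(self):
        # simple running mean per key, hiding "_"-prefixed intermediates
        return {k: self._sums[k] / self._counts[k]
                for k in self._sums if not k.startswith("_")}

@contextmanager
def aggregate():
    agg = _Aggregator()
    yield agg

with aggregate() as agg:
    for loss in (4.0, 3.5, 3.0):
        agg.log_scalar("loss", loss)
print(agg.get_smoothed_values())  # {'loss': 3.5}
```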

tests/test_binaries.py

Lines changed: 13 additions & 0 deletions
@@ -128,6 +128,19 @@ def test_generation(self):
                 ])
                 generate_main(data_dir, ['--prefix-size', '2'])
 
+    def test_eval_bleu(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory('test_eval_bleu') as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir)
+                train_translation_model(data_dir, 'fconv_iwslt_de_en', [
+                    '--eval-bleu',
+                    '--eval-bleu-print-samples',
+                    '--eval-bleu-remove-bpe',
+                    '--eval-bleu-detok', 'space',
+                    '--eval-bleu-args', '{"beam": 4, "min_len": 10}',
+                ])
+
     def test_lstm(self):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory('test_lstm') as data_dir:
