# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

+from argparse import Namespace
+import json
import itertools
import logging
import os

-from fairseq import options, utils
+import numpy as np
+
+from fairseq import metrics, options, utils
from fairseq.data import (
    AppendTokenDataset,
    ConcatDataset,
    data_utils,
+    encoders,
    indexed_dataset,
    LanguagePairDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TruncateDataset,
)

-from . import FairseqTask, register_task
+from fairseq.tasks import FairseqTask, register_task
+
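+# highest n-gram order tracked when accumulating BLEU statistics (1- through 4-grams)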
+EVAL_BLEU_ORDER = 4


logger = logging.getLogger(__name__)
@@ -155,6 +162,26 @@ def add_args(parser):
                            help='amount to upsample primary dataset')
        parser.add_argument('--truncate-source', action='store_true', default=False,
                            help='truncate source to max-source-positions')
+
+        # options for reporting BLEU during validation
+        parser.add_argument('--eval-bleu', action='store_true',
+                            help='evaluate with BLEU scores during validation')
+        parser.add_argument('--eval-bleu-detok', type=str, default="space",
+                            help='detokenizer before computing BLEU (e.g., "moses"); '
+                                 'required if using --eval-bleu; use "space" to '
+                                 'disable detokenization; see fairseq.data.encoders '
+                                 'for other options')
+        parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON',
+                            help='args for building the tokenizer, if needed')
+        parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False,
+                            help='if set, compute tokenized BLEU instead of sacrebleu')
+        parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None,
+                            help='remove BPE before computing BLEU')
+        parser.add_argument('--eval-bleu-args', type=str, metavar='JSON',
+                            help='generation args for BLEU scoring, '
+                                 'e.g., \'{"beam": 4, "lenpen": 0.6}\'')
+        parser.add_argument('--eval-bleu-print-samples', action='store_true',
+                            help='print sample generations during validation')
        # fmt: on

    def __init__(self, args, src_dict, tgt_dict):
@@ -219,6 +246,75 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    def build_dataset_for_inference(self, src_tokens, src_lengths):
        return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary)

+    def build_model(self, args):
+        if getattr(args, 'eval_bleu', False):
+            assert getattr(args, 'eval_bleu_detok', None) is not None, (
+                '--eval-bleu-detok is required if using --eval-bleu; '
+                'try --eval-bleu-detok=moses (or --eval-bleu-detok=space '
+                'to disable detokenization, e.g., when using sentencepiece)'
+            )
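+            # build the detokenizer once here so that it can be reused in valid_step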
+            detok_args = json.loads(getattr(args, 'eval_bleu_detok_args', '{}') or '{}')
+            self.tokenizer = encoders.build_tokenizer(Namespace(
+                tokenizer=getattr(args, 'eval_bleu_detok', None),
+                **detok_args
+            ))
+
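+            # generation options (beam size, length penalty, ...) arrive as a
+            # JSON string; build one generator to reuse for all validation batches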
+            gen_args = json.loads(getattr(args, 'eval_bleu_args', '{}') or '{}')
+            self.sequence_generator = self.build_generator(Namespace(**gen_args))
+        return super().build_model(args)
+
+    def valid_step(self, sample, model, criterion):
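+        # compute the standard validation loss, then attach BLEU statistics to the logging output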
+        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
+        if self.args.eval_bleu:
+            bleu = self._inference_with_bleu(self.sequence_generator, sample, model)
+            logging_output['_bleu_sys_len'] = bleu.sys_len
+            logging_output['_bleu_ref_len'] = bleu.ref_len
+            # we split counts into separate entries so that they can be
+            # summed efficiently across workers using fast-stat-sync
+            assert len(bleu.counts) == EVAL_BLEU_ORDER
+            for i in range(EVAL_BLEU_ORDER):
+                logging_output['_bleu_counts_' + str(i)] = bleu.counts[i]
+                logging_output['_bleu_totals_' + str(i)] = bleu.totals[i]
+        return loss, sample_size, logging_output
+
+    def reduce_metrics(self, logging_outputs, criterion):
+        super().reduce_metrics(logging_outputs, criterion)
+        if self.args.eval_bleu:
+
+            def sum_logs(key):
+                return sum(log.get(key, 0) for log in logging_outputs)
+
+            counts, totals = [], []
+            for i in range(EVAL_BLEU_ORDER):
+                counts.append(sum_logs('_bleu_counts_' + str(i)))
+                totals.append(sum_logs('_bleu_totals_' + str(i)))
+
+            if max(totals) > 0:
+                # log counts as numpy arrays -- log_scalar will sum them correctly
+                metrics.log_scalar('_bleu_counts', np.array(counts))
+                metrics.log_scalar('_bleu_totals', np.array(totals))
+                metrics.log_scalar('_bleu_sys_len', sum_logs('_bleu_sys_len'))
+                metrics.log_scalar('_bleu_ref_len', sum_logs('_bleu_ref_len'))
+
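+                # BLEU itself is derived lazily from the aggregated meters, so
+                # it reflects counts summed over all workers and batches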
+                def compute_bleu(meters):
+                    import inspect
+                    import sacrebleu
+                    fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0]
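+                    # newer sacrebleu releases renamed the 'smooth' keyword to 'smooth_method'; support both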
+                    if 'smooth_method' in fn_sig:
+                        smooth = {'smooth_method': 'exp'}
+                    else:
+                        smooth = {'smooth': 'exp'}
+                    bleu = sacrebleu.compute_bleu(
+                        correct=meters['_bleu_counts'].sum,
+                        total=meters['_bleu_totals'].sum,
+                        sys_len=meters['_bleu_sys_len'].sum,
+                        ref_len=meters['_bleu_ref_len'].sum,
+                        **smooth
+                    )
+                    return round(bleu.score, 2)
+
+                metrics.log_derived('bleu', compute_bleu)
+

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.args.max_source_positions, self.args.max_target_positions)
@@ -232,3 +328,30 @@ def source_dictionary(self):
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.tgt_dict
+
+    def _inference_with_bleu(self, generator, sample, model):
+        import sacrebleu
+
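+        # helper: turn a tensor of token ids into a detokenized string for scoring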
+        def decode(toks, escape_unk=False):
+            s = self.tgt_dict.string(
+                toks.int().cpu(),
+                self.args.eval_bleu_remove_bpe,
+                escape_unk=escape_unk,
+            )
+            if self.tokenizer:
+                s = self.tokenizer.decode(s)
+            return s
+
+        gen_out = self.inference_step(generator, [model], sample, None)
+        hyps, refs = [], []
+        for i in range(len(gen_out)):
+            hyps.append(decode(gen_out[i][0]['tokens']))
+            refs.append(decode(
+                utils.strip_pad(sample['target'][i], self.tgt_dict.pad()),
+                escape_unk=True,  # don't count <unk> as matches to the hypo
+            ))
+        if self.args.eval_bleu_print_samples:
+            logger.info('example hypothesis: ' + hyps[0])
+            logger.info('example reference: ' + refs[0])
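+        # sacrebleu tokenizes internally by default; use 'none' when the text is already tokenized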
+        tokenize = sacrebleu.DEFAULT_TOKENIZER if not self.args.eval_tokenized_bleu else 'none'
+        return sacrebleu.corpus_bleu(hyps, [refs], tokenize=tokenize)
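
A minimal usage sketch (not part of the commit): the new flags plug into
fairseq-train, and the derived 'bleu' metric can drive checkpoint selection.
The data path and architecture below are placeholders, and the usual training
flags (optimizer, lr, criterion, batch size) are omitted:

    fairseq-train data-bin/wmt17.en-de \
        --arch transformer \
        --eval-bleu \
        --eval-bleu-detok moses \
        --eval-bleu-remove-bpe \
        --eval-bleu-args '{"beam": 4, "lenpen": 0.6}' \
        --eval-bleu-print-samples \
        --best-checkpoint-metric bleu \
        --maximize-best-checkpoint-metric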