24
24
25
25
from .event_handler import MetricHandler , ValidationHandler , LoggingHandler , StoppingHandler
26
26
from .event_handler import TrainBegin , EpochBegin , BatchBegin , BatchEnd , EpochEnd , TrainEnd
27
- from .utils import _check_metrics
27
+ from .event_handler import _check_event_handlers
28
+ from .utils import _check_metrics , _suggest_metric_for_loss , _check_handler_metric_ref
28
29
from ...data import DataLoader
29
- from ...loss import SoftmaxCrossEntropyLoss
30
30
from ...loss import Loss as gluon_loss
31
31
from ...trainer import Trainer
32
32
from ...utils import split_and_load
33
33
from .... import autograd
34
34
from ....context import Context , cpu , gpu , num_gpus
35
- from ....metric import Accuracy
36
35
from ....metric import Loss as metric_loss
37
36
38
37
__all__ = ['Estimator' ]
@@ -48,8 +47,8 @@ class Estimator(object):
48
47
----------
49
48
net : gluon.Block
50
49
The model used for training.
51
- loss : gluon.loss.Loss or list of gluon.loss.Loss
52
- Loss(objective functions) to calculate during training.
50
+ loss : gluon.loss.Loss
51
+ Loss (objective) function to calculate during training.
53
52
metrics : EvalMetric or list of EvalMetric
54
53
Metrics for evaluating models.
55
54
initializer : Initializer
@@ -69,19 +68,17 @@ def __init__(self, net,
69
68
70
69
self .net = net
71
70
self .loss = self ._check_loss (loss )
72
- self .train_metrics = _check_metrics (metrics )
71
+ self ._train_metrics = _check_metrics (metrics )
72
+ self ._add_default_training_metrics ()
73
+ self ._add_validation_metrics ()
73
74
74
75
self .context = self ._check_context (context )
75
76
self ._initialize (initializer )
76
77
self .trainer = self ._check_trainer (trainer )
77
78
78
79
def _check_loss (self , loss ):
79
- if isinstance (loss , gluon_loss ):
80
- loss = [loss ]
81
- elif isinstance (loss , list ) and all ([isinstance (l , gluon_loss ) for l in loss ]):
82
- loss = loss
83
- else :
84
- raise ValueError ("loss must be a Loss or a list of Loss, "
80
+ if not isinstance (loss , gluon_loss ):
81
+ raise ValueError ("loss must be a Loss, "
85
82
"refer to gluon.loss.Loss:{}" .format (loss ))
86
83
return loss
87
84
@@ -166,31 +163,30 @@ def _get_data_and_label(self, batch, ctx, batch_axis=0):
166
163
label = split_and_load (label , ctx_list = ctx , batch_axis = batch_axis )
167
164
return data , label
168
165
169
- def prepare_loss_and_metrics (self ):
170
- """
171
- Based on loss functions and training metrics in estimator
172
- Create metric wrappers to record loss values,
173
- Create copies of train loss/metric objects to record validation values
166
+ def _add_default_training_metrics (self ):
167
+ if not self ._train_metrics :
168
+ suggested_metric = _suggest_metric_for_loss (self .loss )
169
+ if suggested_metric :
170
+ self ._train_metrics = [suggested_metric ]
171
+ loss_name = self .loss .name .rstrip ('1234567890' )
172
+ self ._train_metrics .append (metric_loss (loss_name ))
174
173
175
- Returns
176
- -------
177
- train_metrics, val_metrics
178
- """
179
- if any (not hasattr (self , attribute ) for attribute in
180
- ['train_metrics' , 'val_metrics' ]):
181
- # Use default mx.metric.Accuracy() for SoftmaxCrossEntropyLoss()
182
- if not self .train_metrics and any ([isinstance (l , SoftmaxCrossEntropyLoss ) for l in self .loss ]):
183
- self .train_metrics = [Accuracy ()]
184
- self .val_metrics = []
185
- for loss in self .loss :
186
- # remove trailing numbers from loss name to avoid confusion
187
- self .train_metrics .append (metric_loss (loss .name .rstrip ('1234567890' )))
188
- for metric in self .train_metrics :
189
- val_metric = copy .deepcopy (metric )
190
- metric .name = "train " + metric .name
191
- val_metric .name = "validation " + val_metric .name
192
- self .val_metrics .append (val_metric )
193
- return self .train_metrics , self .val_metrics
174
+ for metric in self ._train_metrics :
175
+ metric .name = "training " + metric .name
176
+
177
+ def _add_validation_metrics (self ):
178
+ self ._val_metrics = [copy .deepcopy (metric ) for metric in self ._train_metrics ]
179
+
180
+ for metric in self ._val_metrics :
181
+ metric .name = "validation " + metric .name
182
+
183
+ @property
184
+ def train_metrics (self ):
185
+ return self ._train_metrics
186
+
187
+ @property
188
+ def val_metrics (self ):
189
+ return self ._val_metrics
194
190
195
191
def evaluate_batch (self ,
196
192
val_batch ,
@@ -209,7 +205,7 @@ def evaluate_batch(self,
209
205
"""
210
206
data , label = self ._get_data_and_label (val_batch , self .context , batch_axis )
211
207
pred = [self .net (x ) for x in data ]
212
- loss = [self .loss [ 0 ] (y_hat , y ) for y_hat , y in zip (pred , label )]
208
+ loss = [self .loss (y_hat , y ) for y_hat , y in zip (pred , label )]
213
209
# update metrics
214
210
for metric in val_metrics :
215
211
if isinstance (metric , metric_loss ):
@@ -275,7 +271,7 @@ def fit_batch(self, train_batch,
275
271
276
272
with autograd .record ():
277
273
pred = [self .net (x ) for x in data ]
278
- loss = [self .loss [ 0 ] (y_hat , y ) for y_hat , y in zip (pred , label )]
274
+ loss = [self .loss (y_hat , y ) for y_hat , y in zip (pred , label )]
279
275
280
276
for l in loss :
281
277
l .backward ()
@@ -377,63 +373,47 @@ def fit(self, train_data,
377
373
handler .train_end (estimator_ref )
378
374
379
375
def _prepare_default_handlers (self , val_data , event_handlers ):
380
- event_handlers = event_handlers or []
381
- default_handlers = []
382
- self .prepare_loss_and_metrics ()
376
+ event_handlers = _check_event_handlers (event_handlers )
377
+ added_default_handlers = []
383
378
384
379
# no need to add to default handler check as StoppingHandler does not use metrics
385
- event_handlers .append (StoppingHandler (self .max_epoch , self .max_batch ))
386
- default_handlers .append ("StoppingHandler" )
380
+ added_default_handlers .append (StoppingHandler (self .max_epoch , self .max_batch ))
387
381
388
382
if not any (isinstance (handler , MetricHandler ) for handler in event_handlers ):
389
- event_handlers .append (MetricHandler (train_metrics = self .train_metrics ))
390
- default_handlers .append ("MetricHandler" )
383
+ added_default_handlers .append (MetricHandler (train_metrics = self .train_metrics ))
391
384
392
385
if not any (isinstance (handler , ValidationHandler ) for handler in event_handlers ):
393
386
# no validation handler
394
387
if val_data :
395
- # add default validation handler if validation data found
396
- event_handlers .append (ValidationHandler (val_data = val_data , eval_fn = self .evaluate ,
397
- val_metrics = self .val_metrics ))
398
- default_handlers .append ("ValidationHandler" )
399
388
val_metrics = self .val_metrics
389
+ # add default validation handler if validation data found
390
+ added_default_handlers .append (ValidationHandler (val_data = val_data ,
391
+ eval_fn = self .evaluate ,
392
+ val_metrics = val_metrics ))
400
393
else :
401
394
# set validation metrics to None if no validation data and no validation handler
402
395
val_metrics = []
403
396
404
397
if not any (isinstance (handler , LoggingHandler ) for handler in event_handlers ):
405
- event_handlers .append (LoggingHandler (train_metrics = self .train_metrics ,
406
- val_metrics = val_metrics ))
407
- default_handlers .append ("LoggingHandler" )
398
+ added_default_handlers .append (LoggingHandler (train_metrics = self .train_metrics ,
399
+ val_metrics = val_metrics ))
408
400
409
401
# if there is a mix of user defined event handlers and default event handlers
410
- # they should have the same set of loss and metrics
411
- if default_handlers and len (event_handlers ) != len (default_handlers ):
412
- msg = "You are training with the following default event handlers: %s. " \
413
- "They use loss and metrics from estimator.prepare_loss_and_metrics(). " \
414
- "Please use the same set of metrics for all your other handlers." % \
415
- ", " .join (default_handlers )
402
+ # they should have the same set of metrics
403
+ mixing_handlers = event_handlers and added_default_handlers
404
+
405
+ event_handlers .extend (added_default_handlers )
406
+
407
+ if mixing_handlers :
408
+ msg = "The following default event handlers are added: {}." .format (
409
+ ", " .join ([type (h ).__name__ for h in added_default_handlers ]))
416
410
warnings .warn (msg )
417
- # check if all handlers has the same set of references to loss and metrics
418
- references = []
411
+
412
+
413
+ # check if all handlers have the same set of references to metrics
414
+ known_metrics = set (self .train_metrics + self .val_metrics )
419
415
for handler in event_handlers :
420
- for attribute in dir (handler ):
421
- if any (keyword in attribute for keyword in ['metric' or 'monitor' ]):
422
- reference = getattr (handler , attribute )
423
- if isinstance (reference , list ):
424
- references += reference
425
- else :
426
- references .append (reference )
427
- # remove None metric references
428
- references = set ([ref for ref in references if ref ])
429
- for metric in references :
430
- if metric not in self .train_metrics + self .val_metrics :
431
- msg = "We have added following default handlers for you: %s and used " \
432
- "estimator.prepare_loss_and_metrics() to pass metrics to " \
433
- "those handlers. Please use the same set of metrics " \
434
- "for all your handlers." % \
435
- ", " .join (default_handlers )
436
- raise ValueError (msg )
416
+ _check_handler_metric_ref (handler , known_metrics )
437
417
438
418
event_handlers .sort (key = lambda handler : getattr (handler , 'priority' , 0 ))
439
419
return event_handlers
0 commit comments