
Commit 9f6070f

szha authored and roywei committed
[Estimator] refactor estimator and clarify docs (#16694)
* refactor estimator and clarify docs
* fix info message and test
* clean up after releasing logging handler
1 parent f3c6be5 · commit 9f6070f

File tree

5 files changed (+139, -107 lines)


python/mxnet/gluon/contrib/estimator/estimator.py

Lines changed: 57 additions & 77 deletions
@@ -24,15 +24,14 @@
 from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
 from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
-from .utils import _check_metrics
+from .event_handler import _check_event_handlers
+from .utils import _check_metrics, _suggest_metric_for_loss, _check_handler_metric_ref
 from ...data import DataLoader
-from ...loss import SoftmaxCrossEntropyLoss
 from ...loss import Loss as gluon_loss
 from ...trainer import Trainer
 from ...utils import split_and_load
 from .... import autograd
 from ....context import Context, cpu, gpu, num_gpus
-from ....metric import Accuracy
 from ....metric import Loss as metric_loss

 __all__ = ['Estimator']
@@ -48,8 +47,8 @@ class Estimator(object):
     ----------
     net : gluon.Block
         The model used for training.
-    loss : gluon.loss.Loss or list of gluon.loss.Loss
-        Loss(objective functions) to calculate during training.
+    loss : gluon.loss.Loss
+        Loss (objective) function to calculate during training.
     metrics : EvalMetric or list of EvalMetric
         Metrics for evaluating models.
     initializer : Initializer
@@ -69,19 +68,17 @@ def __init__(self, net,

         self.net = net
         self.loss = self._check_loss(loss)
-        self.train_metrics = _check_metrics(metrics)
+        self._train_metrics = _check_metrics(metrics)
+        self._add_default_training_metrics()
+        self._add_validation_metrics()

         self.context = self._check_context(context)
         self._initialize(initializer)
         self.trainer = self._check_trainer(trainer)

     def _check_loss(self, loss):
-        if isinstance(loss, gluon_loss):
-            loss = [loss]
-        elif isinstance(loss, list) and all([isinstance(l, gluon_loss) for l in loss]):
-            loss = loss
-        else:
-            raise ValueError("loss must be a Loss or a list of Loss, "
+        if not isinstance(loss, gluon_loss):
+            raise ValueError("loss must be a Loss, "
                              "refer to gluon.loss.Loss:{}".format(loss))
         return loss

@@ -166,31 +163,30 @@ def _get_data_and_label(self, batch, ctx, batch_axis=0):
         label = split_and_load(label, ctx_list=ctx, batch_axis=batch_axis)
         return data, label

-    def prepare_loss_and_metrics(self):
-        """
-        Based on loss functions and training metrics in estimator
-        Create metric wrappers to record loss values,
-        Create copies of train loss/metric objects to record validation values
+    def _add_default_training_metrics(self):
+        if not self._train_metrics:
+            suggested_metric = _suggest_metric_for_loss(self.loss)
+            if suggested_metric:
+                self._train_metrics = [suggested_metric]
+            loss_name = self.loss.name.rstrip('1234567890')
+            self._train_metrics.append(metric_loss(loss_name))

-        Returns
-        -------
-        train_metrics, val_metrics
-        """
-        if any(not hasattr(self, attribute) for attribute in
-               ['train_metrics', 'val_metrics']):
-            # Use default mx.metric.Accuracy() for SoftmaxCrossEntropyLoss()
-            if not self.train_metrics and any([isinstance(l, SoftmaxCrossEntropyLoss) for l in self.loss]):
-                self.train_metrics = [Accuracy()]
-            self.val_metrics = []
-            for loss in self.loss:
-                # remove trailing numbers from loss name to avoid confusion
-                self.train_metrics.append(metric_loss(loss.name.rstrip('1234567890')))
-            for metric in self.train_metrics:
-                val_metric = copy.deepcopy(metric)
-                metric.name = "train " + metric.name
-                val_metric.name = "validation " + val_metric.name
-                self.val_metrics.append(val_metric)
-        return self.train_metrics, self.val_metrics
+        for metric in self._train_metrics:
+            metric.name = "training " + metric.name
+
+    def _add_validation_metrics(self):
+        self._val_metrics = [copy.deepcopy(metric) for metric in self._train_metrics]
+
+        for metric in self._val_metrics:
+            metric.name = "validation " + metric.name
+
+    @property
+    def train_metrics(self):
+        return self._train_metrics
+
+    @property
+    def val_metrics(self):
+        return self._val_metrics

     def evaluate_batch(self,
                        val_batch,
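
The `rstrip('1234567890')` above is easy to miss: Gluon auto-numbers block names, so a loss instance is typically named something like `softmaxcrossentropyloss0`. A minimal sketch of the name cleanup (the loss instance here is illustrative):

```python
from mxnet import gluon
from mxnet.metric import Loss as metric_loss

loss = gluon.loss.SoftmaxCrossEntropyLoss()
print(loss.name)                    # e.g. 'softmaxcrossentropyloss0'

# strip the trailing instance number so the wrapped loss metric keeps a
# stable, readable name across runs
loss_metric = metric_loss(loss.name.rstrip('1234567890'))
print(loss_metric.name)             # 'softmaxcrossentropyloss'
```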
@@ -209,7 +205,7 @@ def evaluate_batch(self,
         """
         data, label = self._get_data_and_label(val_batch, self.context, batch_axis)
         pred = [self.net(x) for x in data]
-        loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
+        loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)]
         # update metrics
         for metric in val_metrics:
             if isinstance(metric, metric_loss):
@@ -275,7 +271,7 @@ def fit_batch(self, train_batch,

         with autograd.record():
             pred = [self.net(x) for x in data]
-            loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
+            loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)]

         for l in loss:
             l.backward()
@@ -377,63 +373,47 @@ def fit(self, train_data,
                 handler.train_end(estimator_ref)

     def _prepare_default_handlers(self, val_data, event_handlers):
-        event_handlers = event_handlers or []
-        default_handlers = []
-        self.prepare_loss_and_metrics()
+        event_handlers = _check_event_handlers(event_handlers)
+        added_default_handlers = []

         # no need to add to default handler check as StoppingHandler does not use metrics
-        event_handlers.append(StoppingHandler(self.max_epoch, self.max_batch))
-        default_handlers.append("StoppingHandler")
+        added_default_handlers.append(StoppingHandler(self.max_epoch, self.max_batch))

         if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
-            event_handlers.append(MetricHandler(train_metrics=self.train_metrics))
-            default_handlers.append("MetricHandler")
+            added_default_handlers.append(MetricHandler(train_metrics=self.train_metrics))

         if not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
             # no validation handler
             if val_data:
-                # add default validation handler if validation data found
-                event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate,
-                                                        val_metrics=self.val_metrics))
-                default_handlers.append("ValidationHandler")
                 val_metrics = self.val_metrics
+                # add default validation handler if validation data found
+                added_default_handlers.append(ValidationHandler(val_data=val_data,
+                                                                eval_fn=self.evaluate,
+                                                                val_metrics=val_metrics))
             else:
                 # set validation metrics to None if no validation data and no validation handler
                 val_metrics = []

         if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
-            event_handlers.append(LoggingHandler(train_metrics=self.train_metrics,
-                                                 val_metrics=val_metrics))
-            default_handlers.append("LoggingHandler")
+            added_default_handlers.append(LoggingHandler(train_metrics=self.train_metrics,
+                                                         val_metrics=val_metrics))

         # if there is a mix of user defined event handlers and default event handlers
-        # they should have the same set of loss and metrics
-        if default_handlers and len(event_handlers) != len(default_handlers):
-            msg = "You are training with the following default event handlers: %s. " \
-                  "They use loss and metrics from estimator.prepare_loss_and_metrics(). " \
-                  "Please use the same set of metrics for all your other handlers." % \
-                  ", ".join(default_handlers)
+        # they should have the same set of metrics
+        mixing_handlers = event_handlers and added_default_handlers
+
+        event_handlers.extend(added_default_handlers)
+
+        if mixing_handlers:
+            msg = "The following default event handlers are added: {}.".format(
+                ", ".join([type(h).__name__ for h in added_default_handlers]))
             warnings.warn(msg)
-            # check if all handlers has the same set of references to loss and metrics
-            references = []
+
+
+        # check if all handlers have the same set of references to metrics
+        known_metrics = set(self.train_metrics + self.val_metrics)
         for handler in event_handlers:
-            for attribute in dir(handler):
-                if any(keyword in attribute for keyword in ['metric' or 'monitor']):
-                    reference = getattr(handler, attribute)
-                    if isinstance(reference, list):
-                        references += reference
-                    else:
-                        references.append(reference)
-        # remove None metric references
-        references = set([ref for ref in references if ref])
-        for metric in references:
-            if metric not in self.train_metrics + self.val_metrics:
-                msg = "We have added following default handlers for you: %s and used " \
-                      "estimator.prepare_loss_and_metrics() to pass metrics to " \
-                      "those handlers. Please use the same set of metrics " \
-                      "for all your handlers." % \
-                      ", ".join(default_handlers)
-                raise ValueError(msg)
+            _check_handler_metric_ref(handler, known_metrics)

         event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
         return event_handlers
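
Taken together, these changes narrow the public surface of `Estimator`: `loss` must be a single `gluon.loss.Loss`, `prepare_loss_and_metrics()` is gone, and the metric objects are created eagerly in `__init__` and exposed through the read-only `train_metrics` / `val_metrics` properties. A hedged usage sketch (the toy network is illustrative, and the printed names assume `_suggest_metric_for_loss` suggests Accuracy for a softmax cross-entropy loss):

```python
from mxnet import gluon
from mxnet.gluon.contrib.estimator import Estimator

net = gluon.nn.Dense(10)   # toy model; any gluon.Block works
net.initialize()

est = Estimator(net=net,
                loss=gluon.loss.SoftmaxCrossEntropyLoss(),  # a single Loss, not a list
                metrics=None)                               # defaults are filled in __init__

# Handlers should reference these exact objects; _check_handler_metric_ref
# raises if a handler points at metrics the estimator does not own.
print([m.name for m in est.train_metrics])  # e.g. ['training accuracy', 'training softmaxcrossentropyloss']
print([m.name for m in est.val_metrics])    # e.g. ['validation accuracy', 'validation softmaxcrossentropyloss']
```

Passing a list of losses, which the old `_check_loss` accepted, now raises `ValueError` up front.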

python/mxnet/gluon/contrib/estimator/event_handler.py

Lines changed: 45 additions & 22 deletions
@@ -16,7 +16,7 @@
 # under the License.

 # coding: utf-8
-# pylint: disable=wildcard-import, unused-argument
+# pylint: disable=wildcard-import, unused-argument, too-many-ancestors
 """Gluon EventHandlers for Estimators"""

 import logging
@@ -34,33 +34,47 @@
            'StoppingHandler', 'MetricHandler', 'ValidationHandler',
            'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler']

+class EventHandler(object):
+    pass

-class TrainBegin(object):
+
+def _check_event_handlers(handlers):
+    if isinstance(handlers, EventHandler):
+        handlers = [handlers]
+    else:
+        handlers = handlers or []
+        if not all([isinstance(handler, EventHandler) for handler in handlers]):
+            raise ValueError("handlers must be an EventHandler or a list of EventHandler, "
+                             "got: {}".format(handlers))
+    return handlers
+
+
+class TrainBegin(EventHandler):
     def train_begin(self, estimator, *args, **kwargs):
         pass


-class TrainEnd(object):
+class TrainEnd(EventHandler):
     def train_end(self, estimator, *args, **kwargs):
         pass


-class EpochBegin(object):
+class EpochBegin(EventHandler):
     def epoch_begin(self, estimator, *args, **kwargs):
         pass


-class EpochEnd(object):
+class EpochEnd(EventHandler):
     def epoch_end(self, estimator, *args, **kwargs):
         return False


-class BatchBegin(object):
+class BatchBegin(EventHandler):
     def batch_begin(self, estimator, *args, **kwargs):
         pass


-class BatchEnd(object):
+class BatchEnd(EventHandler):
     def batch_end(self, estimator, *args, **kwargs):
         return False

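The new `EventHandler` base class exists mostly so `_check_event_handlers` has something to validate against; all the lifecycle mixins now inherit from it, which is also why the `too-many-ancestors` pylint suppression was added at the top of the file. A small sketch of the intended contract (the `BatchCounter` handler is illustrative):

```python
from mxnet.gluon.contrib.estimator.event_handler import (
    TrainBegin, BatchEnd, _check_event_handlers)

# a custom handler mixes in only the events it needs; through the mixins it
# is now an EventHandler, so _check_event_handlers accepts it
class BatchCounter(TrainBegin, BatchEnd):
    def train_begin(self, estimator, *args, **kwargs):
        self.count = 0

    def batch_end(self, estimator, *args, **kwargs):
        self.count += 1
        return False   # returning True would stop training early

print(_check_event_handlers(BatchCounter()))  # a bare handler is wrapped in a list
print(_check_event_handlers(None))            # None normalizes to []
_check_event_handlers(["oops"])               # raises ValueError
```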

@@ -242,14 +256,16 @@ def __init__(self, file_name=None,
         super(LoggingHandler, self).__init__()
         self.logger = logging.getLogger(__name__)
         self.logger.setLevel(logging.INFO)
-        stream_handler = logging.StreamHandler()
-        self.logger.addHandler(stream_handler)
+        self._added_logging_handlers = [logging.StreamHandler()]
         # save logger to file only if file name or location is specified
         if file_name or file_location:
             file_name = file_name or 'estimator_log'
             file_location = file_location or './'
             file_handler = logging.FileHandler(os.path.join(file_location, file_name), mode=filemode)
-            self.logger.addHandler(file_handler)
+            self._added_logging_handlers.append(file_handler)
+        for handler in self._added_logging_handlers:
+            self.logger.addHandler(handler)
+
         if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]:
             raise ValueError("verbose level must be either LOG_PER_EPOCH or "
                              "LOG_PER_BATCH, received %s. "
@@ -265,6 +281,12 @@ def __init__(self, file_name=None,
         # it will also shut down logging at train end
         self.priority = np.Inf

+    def __del__(self):
+        for handler in self._added_logging_handlers:
+            handler.flush()
+            self.logger.removeHandler(handler)
+            handler.close()
+
     def train_begin(self, estimator, *args, **kwargs):
         self.train_start = time.time()
         trainer = estimator.trainer
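
The `_added_logging_handlers` bookkeeping and the new `__del__` are the "clean up after releasing logging handler" part of the commit message: `LoggingHandler` attaches to a module-level logger, so each instance used to leave a `StreamHandler` behind and subsequent fits printed every log line multiple times. A hedged sketch of the fixed behavior, assuming CPython's reference counting runs `__del__` as soon as the last reference is dropped:

```python
import logging
from mxnet.gluon.contrib.estimator.event_handler import LoggingHandler

# the shared module-level logger that LoggingHandler attaches to
logger = logging.getLogger("mxnet.gluon.contrib.estimator.event_handler")

before = len(logger.handlers)
log_handler = LoggingHandler()            # attaches one StreamHandler
assert len(logger.handlers) == before + 1

del log_handler                           # __del__ flushes, detaches, and closes it
assert len(logger.handlers) == before     # no stale handlers left behind
```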
@@ -393,8 +415,8 @@ def __init__(self,
         self.model_prefix = model_prefix
         self.save_best = save_best
         if self.save_best and not isinstance(self.monitor, EvalMetric):
-            raise ValueError("To save best model only, please provide one of the metric objects as monitor, "
-                             "You can get these objects using estimator.prepare_loss_and_metric()")
+            raise ValueError("To save best model only, please provide one of the metric objects "
+                             "from estimator.train_metrics and estimator.val_metrics as monitor.")
         self.epoch_period = epoch_period
         self.batch_period = batch_period
         self.current_batch = 0
@@ -487,10 +509,10 @@ def _save_checkpoint(self, estimator):
         monitor_name, monitor_value = self.monitor.get()
         # check if monitor exists in train stats
         if np.isnan(monitor_value):
-            warnings.warn(RuntimeWarning('Skipping save best because %s is not updated, make sure you '
-                                         'pass one of the metric objects as monitor, '
-                                         'you can use estimator.prepare_loss_and_metrics to'
-                                         'create all metric objects', monitor_name))
+            warnings.warn(RuntimeWarning(
+                'Skipping save best because %s is not updated, make sure you pass one of the '
+                'metric objects estimator.train_metrics and estimator.val_metrics as monitor',
+                monitor_name))
         else:
             if self.monitor_op(monitor_value, self.best):
                 prefix = self.model_prefix + '-best'
@@ -517,7 +539,7 @@ def _save_symbol(self, estimator):
             sym.save(symbol_file)
         else:
             self.logger.info("Model architecture(symbol file) is not saved, please use HybridBlock "
-                             "to construct your model, can call net.hybridize() before passing to "
+                             "to construct your model, and call net.hybridize() before passing to "
                              "Estimator in order to save model architecture as %s.", symbol_file)

     def _save_params_and_trainer(self, estimator, file_prefix):
@@ -636,8 +658,9 @@ def __init__(self,
         super(EarlyStoppingHandler, self).__init__()

         if not isinstance(monitor, EvalMetric):
-            raise ValueError("Please provide one of the metric objects as monitor, "
-                             "You can create these objects using estimator.prepare_loss_and_metric()")
+            raise ValueError(
+                "Please provide one of the metric objects from estimator.train_metrics and "
+                "estimator.val_metrics as monitor.")
         if isinstance(monitor, CompositeEvalMetric):
             raise ValueError("CompositeEvalMetric is not supported for EarlyStoppingHandler, "
                              "please specify a simple metric instead.")
@@ -693,9 +716,9 @@ def train_begin(self, estimator, *args, **kwargs):
     def epoch_end(self, estimator, *args, **kwargs):
         monitor_name, monitor_value = self.monitor.get()
         if np.isnan(monitor_value):
-            warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects'
-                                         'as monitor, you can use estimator.prepare_loss_and_metrics to'
-                                         'create all metric objects', monitor_name))
+            warnings.warn(RuntimeWarning(
+                '%s is not updated, make sure you pass one of the metric objects from'
+                'estimator.train_metrics and estimator.val_metrics as monitor.', monitor_name))
         else:
             if self.monitor_op(monitor_value - self.min_delta, self.best):
                 self.best = monitor_value

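The reworded errors and warnings in `CheckpointHandler` and `EarlyStoppingHandler` all enforce the same contract: a monitored metric must be one of the estimator's own metric objects, now reachable via the `train_metrics` / `val_metrics` properties rather than the removed `prepare_loss_and_metrics()`. A brief hedged sketch, reusing the `est` from the estimator example above:

```python
from mxnet.gluon.contrib.estimator import EarlyStoppingHandler

# pass one of the estimator's own metric objects so the handler reads the
# same state the training loop updates
early_stop = EarlyStoppingHandler(monitor=est.val_metrics[0])

# a plain string is rejected immediately instead of silently never updating:
# EarlyStoppingHandler(monitor="validation accuracy")  # -> ValueError
```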