Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 2ad3224

Browse files
author
Sheng Zha
committed
fix composite metric case in handlers
1 parent 9396030 commit 2ad3224

File tree

4 files changed

+50
-20
lines changed

4 files changed

+50
-20
lines changed

python/mxnet/gluon/contrib/estimator/estimator.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
2626
from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
27+
from .utils import _check_metrics
2728
from ...data import DataLoader
2829
from ...loss import SoftmaxCrossEntropyLoss
2930
from ...loss import Loss as gluon_loss
@@ -68,7 +69,7 @@ def __init__(self, net,
6869

6970
self.net = net
7071
self.loss = self._check_loss(loss)
71-
self.train_metrics = self._check_metrics(metrics)
72+
self.train_metrics = _check_metrics(metrics)
7273

7374
self.context = self._check_context(context)
7475
self._initialize(initializer)
@@ -84,18 +85,6 @@ def _check_loss(self, loss):
8485
"refer to gluon.loss.Loss:{}".format(loss))
8586
return loss
8687

87-
def _check_metrics(self, metrics):
88-
if isinstance(metrics, CompositeEvalMetric):
89-
metrics = metrics.metrics
90-
elif isinstance(metrics, EvalMetric):
91-
metrics = [metrics]
92-
else:
93-
metrics = metrics or []
94-
if not all([isinstance(metric, EvalMetric) for metric in metrics]):
95-
raise ValueError("metrics must be a Metric or a list of Metric, "
96-
"refer to mxnet.metric.EvalMetric:{}".format(metrics))
97-
return metrics
98-
9988
def _check_context(self, context):
10089
# infer available context
10190
gpus = num_gpus()

python/mxnet/gluon/contrib/estimator/event_handler.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@
2626

2727
import numpy as np
2828

29-
from ....metric import EvalMetric
29+
from ....metric import EvalMetric, CompositeEvalMetric
3030
from ....metric import Loss as metric_loss
31+
from .utils import _check_metrics
3132

3233
__all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd', 'BatchBegin', 'BatchEnd',
3334
'StoppingHandler', 'MetricHandler', 'ValidationHandler',
@@ -118,7 +119,7 @@ class MetricHandler(EpochBegin, BatchEnd):
118119
"""
119120

120121
def __init__(self, train_metrics):
121-
self.train_metrics = train_metrics or []
122+
self.train_metrics = _check_metrics(train_metrics)
122123
# order to be called among all callbacks
123124
# metrics need to be calculated before other callbacks can access them
124125
self.priority = -np.Inf
@@ -173,7 +174,7 @@ def __init__(self,
173174
self.eval_fn = eval_fn
174175
self.epoch_period = epoch_period
175176
self.batch_period = batch_period
176-
self.val_metrics = val_metrics
177+
self.val_metrics = _check_metrics(val_metrics)
177178
self.current_batch = 0
178179
self.current_epoch = 0
179180
# order to be called among all callbacks
@@ -255,8 +256,8 @@ def __init__(self, file_name=None,
255256
"E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)"
256257
% verbose)
257258
self.verbose = verbose
258-
self.train_metrics = train_metrics or []
259-
self.val_metrics = val_metrics or []
259+
self.train_metrics = _check_metrics(train_metrics)
260+
self.val_metrics = _check_metrics(val_metrics)
260261
self.batch_index = 0
261262
self.current_epoch = 0
262263
self.processed_samples = 0
@@ -637,6 +638,9 @@ def __init__(self,
637638
if not isinstance(monitor, EvalMetric):
638639
raise ValueError("Please provide one of the metric objects as monitor, "
639640
"You can create these objects using estimator.prepare_loss_and_metric()")
641+
if isinstance(monitor, CompositeEvalMetric):
642+
raise ValueError("CompositeEvalMetric is not supported for EarlyStoppingHandler, "
643+
"please specify a simple metric instead.")
640644
self.monitor = monitor
641645
self.baseline = baseline
642646
self.patience = patience
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# coding: utf-8
19+
# pylint: disable=wildcard-import, unused-variable
20+
"""Gluon Estimator Utility Functions"""
21+
22+
from ....metric import EvalMetric, CompositeEvalMetric
23+
24+
def _check_metrics(metrics):
25+
if isinstance(metrics, CompositeEvalMetric):
26+
metrics = [m for metric in metrics.metrics for m in _check_metrics(metric)]
27+
elif isinstance(metrics, EvalMetric):
28+
metrics = [metrics]
29+
else:
30+
metrics = metrics or []
31+
if not all([isinstance(metric, EvalMetric) for metric in metrics]):
32+
raise ValueError("metrics must be a Metric or a list of Metric, "
33+
"refer to mxnet.metric.EvalMetric:{}".format(metrics))
34+
return metrics

tests/nightly/estimator/test_sentiment_rnn.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,13 @@ def run(net, train_dataloader, test_dataloader, num_epochs, ctx, lr):
191191
# Define loss and evaluation metrics
192192
loss = gluon.loss.SoftmaxCrossEntropyLoss()
193193
metrics = mx.metric.CompositeEvalMetric()
194-
metrics.add([mx.metric.Accuracy(), mx.metric.Loss()])
194+
acc = mx.metric.Accuracy()
195+
nested_metrics = mx.metric.CompositeEvalMetric()
196+
metrics.add([acc, mx.metric.Loss()])
197+
nested_metrics.add([metrics, mx.metric.Accuracy()])
195198

196199
# Define estimator
197-
est = estimator.Estimator(net=net, loss=loss, metrics=metrics,
200+
est = estimator.Estimator(net=net, loss=loss, metrics=nested_metrics,
198201
trainer=trainer, context=ctx)
199202
# Begin training
200203
est.fit(train_data=train_dataloader, val_data=test_dataloader,

0 commit comments

Comments
 (0)