This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 69cd3a1

stephenyan1231 authored and facebook-github-bot committed
mixup data augmentation
Summary: This diff implements the mixup data augmentation from the paper `mixup: Beyond Empirical Risk Minimization` (https://arxiv.org/abs/1710.09412)

Differential Revision: D20911088

fbshipit-source-id: 36a958ef4f711d122064fae736fed7a7e91b81e8
1 parent 6214d10 commit 69cd3a1
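For context, mixup replaces each training example with a convex combination of two examples and of their one-hot labels: x~ = lam * x_i + (1 - lam) * x_j and y~ = lam * y_i + (1 - lam) * y_j, with lam drawn from a Beta(alpha, alpha) distribution. Below is a minimal sketch of that operation on a single pair, using plain tensors rather than the Classy Vision sample dict; mixup_pair is a hypothetical helper for illustration, not part of this commit.

import torch
from torch.distributions.beta import Beta


def mixup_pair(x1, y1_onehot, x2, y2_onehot, alpha=0.2):
    # Hypothetical helper: illustrates the mixup operation from the paper.
    # Draw the mixing coefficient lam ~ Beta(alpha, alpha).
    lam = Beta(torch.tensor([alpha]), torch.tensor([alpha])).sample()
    # Convex combination of the inputs and of the one-hot labels.
    mixed_x = lam * x1 + (1.0 - lam) * x2
    mixed_y = lam * y1_onehot + (1.0 - lam) * y2_onehot
    return mixed_x, mixed_y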

File tree

classy_vision/dataset/transforms/mixup.py
classy_vision/generic/util.py
classy_vision/tasks/classification_task.py

3 files changed: +64 -5 lines changed
classy_vision/dataset/transforms/mixup.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from classy_vision.generic.util import convert_to_one_hot
from torch.distributions.beta import Beta


def mixup_transform(sample, num_classes, alpha):
    """
    This implements the mixup data augmentation in the paper
    "mixup: Beyond Empirical Risk Minimization" (https://arxiv.org/abs/1710.09412)

    Args:
        sample (Dict[str, Any]): the batch data
        num_classes (int): number of dataset classes
        alpha (float): the hyperparameter of the Beta distribution used to sample
            the mixup coefficient.
    """
    assert (
        sample["target"].ndim == 1
    ), "Currently mixup only supports single-label classification"
    sample["target"] = convert_to_one_hot(sample["target"].view(-1, 1), num_classes)

    c = Beta(torch.tensor([alpha]), torch.tensor([alpha])).sample()

    for key in ["input", "target"]:
        sample[key] = c * sample[key] + (1.0 - c) * sample[key].flip([0])

    return sample
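A usage sketch for the transform above on a toy batch; the batch size, class count, and alpha are illustrative only, and the import path assumes the module lands at classy_vision/dataset/transforms/mixup.py, as suggested by the import added to classification_task.py below.

import torch
from classy_vision.dataset.transforms.mixup import mixup_transform

# Toy batch: 4 images and integer class labels for a 10-class problem.
sample = {
    "input": torch.randn(4, 3, 32, 32),
    "target": torch.tensor([1, 0, 3, 7]),
}
sample = mixup_transform(sample, num_classes=10, alpha=0.2)

# The inputs are now blended with the reversed batch and the targets are
# soft (fractional) one-hot vectors; each target row still sums to 1.
print(sample["input"].shape)        # torch.Size([4, 3, 32, 32])
print(sample["target"].shape)       # torch.Size([4, 10])
print(sample["target"].sum(dim=1))  # each entry is 1.0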

classy_vision/generic/util.py

Lines changed: 4 additions & 5 deletions
@@ -736,11 +736,10 @@ def maybe_convert_to_one_hot(target, model_output):
     ):
         target = convert_to_one_hot(target.view(-1, 1), model_output.shape[1])
 
-    assert (target.shape == model_output.shape) and (
-        torch.min(target.eq(0) + target.eq(1)) == 1
-    ), (
-        "Target must be one-hot/multi-label encoded and of the "
-        "same shape as model_output."
+    # The target does not have to be a hard 0/1 encoding; it can be soft
+    # (i.e. fractional), such as a mixup label.
+    assert target.shape == model_output.shape, (
+        "Target must be of the same shape as model_output."
     )
 
     return target
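The relaxed assertion matters because a mixup target is a fractional mixture of two one-hot vectors: it still matches model_output in shape, but it no longer passes an element-wise 0/1 test. A small illustration of that difference (not code from this commit):

import torch

y1 = torch.tensor([[1.0, 0.0, 0.0]])
y2 = torch.tensor([[0.0, 0.0, 1.0]])
soft_target = 0.7 * y1 + 0.3 * y2       # a mixup-style label: [[0.7, 0.0, 0.3]]
model_output = torch.randn(1, 3)

# The old requirement (every entry exactly 0 or 1) fails for soft labels.
is_hard_encoded = bool((soft_target.eq(0) | soft_target.eq(1)).all())
# The relaxed requirement kept by this commit: shapes must still agree.
shapes_match = soft_target.shape == model_output.shape

print(is_hard_encoded, shapes_match)  # False True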

classy_vision/tasks/classification_task.py

Lines changed: 28 additions & 0 deletions
@@ -13,13 +13,15 @@
 import torch
 import torch.nn as nn
 from classy_vision.dataset import ClassyDataset, build_dataset
+from classy_vision.dataset.transforms.mixup import mixup_transform
 from classy_vision.generic.distributed_util import (
     all_reduce_mean,
     barrier,
     init_distributed_data_parallel_model,
     is_distributed_training_run,
 )
 from classy_vision.generic.util import (
+    convert_to_one_hot,
     copy_model_to_gpu,
     recursive_copy_to_gpu,
     update_classy_state,
@@ -139,6 +141,7 @@ def __init__(self):
             BroadcastBuffersMode.DISABLED
         )
         self.amp_args = None
+        self.mixup_args = None
         self.perf_log = []
         self.last_batch = None
         self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED
@@ -306,6 +309,20 @@ def set_amp_args(self, amp_args: Optional[Dict[str, Any]]):
             logging.info(f"AMP enabled with args {amp_args}")
         return self
 
+    def set_mixup_args(self, mixup_args: Optional[Dict[str, Any]]):
+        """Disable / enable mixup data augmentation
+
+        Args:
+            mixup_args: expected to include the following keys in the dictionary
+                num_classes (int): number of dataset classes
+                alpha (float): the hyperparameter of the Beta distribution used to
+                    sample the mixup coefficient.
+        """
+        self.mixup_args = mixup_args
+        if mixup_args is None:
+            logging.info("mixup disabled")
+        return self
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
         """Instantiates a ClassificationTask from a configuration.
@@ -348,6 +365,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
             .set_optimizer(optimizer)
             .set_meters(meters)
             .set_amp_args(amp_args)
+            .set_mixup_args(config.get("mixup"))
             .set_distributed_options(
                 broadcast_buffers_mode=BroadcastBuffersMode[
                     config.get("broadcast_buffers", "disabled").upper()
@@ -697,6 +715,11 @@ def eval_step(self, use_gpu):
                 + "'target' keys"
             )
 
+        if self.mixup_args is not None:
+            sample["target"] = convert_to_one_hot(
+                sample["target"].view(-1, 1), self.mixup_args["num_classes"]
+            )
+
         # Copy sample to GPU
         target = sample["target"]
         if use_gpu:
@@ -743,6 +766,11 @@ def train_step(self, use_gpu):
                 + "'target' keys"
             )
 
+        if self.mixup_args is not None:
+            sample = mixup_transform(
+                sample, self.mixup_args["num_classes"], self.mixup_args["alpha"]
+            )
+
         # Copy sample to GPU
         target = sample["target"]
         if use_gpu:
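Note the asymmetry between the two hooks: train_step mixes both inputs and targets, while eval_step only converts the integer targets to one-hot so that the loss and meters see the same target shape in both phases. A condensed, hypothetical restatement of that control flow follows (not the actual step implementations; it assumes the imports added in this diff).

def _maybe_apply_mixup(sample, mixup_args, is_train):
    # Hypothetical helper summarizing the two hooks added in this commit.
    if mixup_args is None:
        return sample
    if is_train:
        # Training: blend inputs and targets; targets become soft one-hot.
        return mixup_transform(sample, mixup_args["num_classes"], mixup_args["alpha"])
    # Evaluation: no blending, but targets are one-hot encoded so their shape
    # matches the soft targets produced during training.
    sample["target"] = convert_to_one_hot(
        sample["target"].view(-1, 1), mixup_args["num_classes"]
    )
    return sample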

0 commit comments
