Skip to content

Commit e7a7422

Browse files
masadcv, yiheng-wang-nv, wyli
authored
Adding EfficientNetsB0-B7 support (#1938)
* adding init efficientnet support Signed-off-by: masadcv <[email protected]> * fixing flake8 and further refactoring Signed-off-by: masadcv <[email protected]> * adding unittests for efficientnet Signed-off-by: masadcv <[email protected]> * making unittests backwards compatible python<3.8 Signed-off-by: masadcv <[email protected]> * fixed kitty unittests file path Signed-off-by: masadcv <[email protected]> * adding docstrings and minor refactoring Signed-off-by: masadcv <[email protected]> * fix flake8-py3 failing test Signed-off-by: masadcv <[email protected]> * generalize drop_connect for n-dim, fix/add unittests, remove assert Signed-off-by: masadcv <[email protected]> * fix failing unittest, CC0-license image for test Signed-off-by: masadcv <[email protected]> * refactoring code for review Signed-off-by: masadcv <[email protected]> * WIP fix mypy type hint errors Signed-off-by: masadcv <[email protected]> * fix cuda test error Signed-off-by: masadcv <[email protected]> * WIP fix test errors Signed-off-by: masadcv <[email protected]> * adding non-default shape tests Signed-off-by: masadcv <[email protected]> * remove 3d case from non-default shape test Signed-off-by: masadcv <[email protected]> * refactoring and updating docs Signed-off-by: masadcv <[email protected]> Co-authored-by: Yiheng Wang <[email protected]> Co-authored-by: Wenqi Li <[email protected]>
1 parent d56c002 commit e7a7422

File tree

10 files changed

+1239
-3
lines changed

10 files changed

+1239
-3
lines changed

docs/source/networks.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ Blocks
3535
.. autoclass:: Swish
3636
:members:
3737

38+
`MemoryEfficientSwish`
39+
~~~~~~~~~~~~~~~~~~~~~~
40+
.. autoclass:: MemoryEfficientSwish
41+
:members:
42+
3843
`Mish`
3944
~~~~~~
4045
.. autoclass:: Mish
@@ -292,6 +297,11 @@ Nets
292297
.. autoclass:: DenseNet
293298
:members:
294299

300+
`EfficientNet`
301+
~~~~~~~~~~~~~~
302+
.. autoclass:: EfficientNet
303+
:members:
304+
295305
`SegResNet`
296306
~~~~~~~~~~~
297307
.. autoclass:: SegResNet

monai/networks/blocks/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# limitations under the License.
1111

1212
from .acti_norm import ADN
13-
from .activation import Mish, Swish
13+
from .activation import MemoryEfficientSwish, Mish, Swish
1414
from .aspp import SimpleASPP
1515
from .convolutions import Convolution, ResidualUnit
1616
from .crf import CRF

monai/networks/blocks/activation.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class Swish(nn.Module):
1717
r"""Applies the element-wise function:
1818
1919
.. math::
20-
\text{Swish}(x) = x * \text{Sigmoid}(\alpha * x) for constant value alpha.
20+
\text{Swish}(x) = x * \text{Sigmoid}(\alpha * x) ~~~~\text{for constant value}~ \alpha.
2121
2222
Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941.
2323
@@ -43,6 +43,57 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
4343
return input * torch.sigmoid(self.alpha * input)
4444

4545

46+
class SwishImplementation(torch.autograd.Function):
47+
r"""Memory efficient implementation for training
48+
Follows recommendation from:
49+
https://github.com/lukemelas/EfficientNet-PyTorch/issues/18#issuecomment-511677853
50+
51+
Results in ~ 30% memory saving during training as compared to Swish()
52+
"""
53+
54+
@staticmethod
55+
def forward(ctx, input):
56+
result = input * torch.sigmoid(input)
57+
ctx.save_for_backward(input)
58+
return result
59+
60+
@staticmethod
61+
def backward(ctx, grad_output):
62+
input = ctx.saved_tensors[0]
63+
sigmoid_input = torch.sigmoid(input)
64+
return grad_output * (sigmoid_input * (1 + input * (1 - sigmoid_input)))
65+
66+
67+
class MemoryEfficientSwish(nn.Module):
68+
r"""Applies the element-wise function:
69+
70+
.. math::
71+
\text{Swish}(x) = x * \text{Sigmoid}(\alpha * x) ~~~~\text{for constant value}~ \alpha=1.
72+
73+
Memory-efficient implementation for training, following the recommendation from:
74+
https://github.com/lukemelas/EfficientNet-PyTorch/issues/18#issuecomment-511677853
75+
76+
Results in ~ 30% memory saving during training as compared to Swish()
77+
78+
Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941.
79+
80+
Shape:
81+
- Input: :math:`(N, *)` where `*` means any number of additional
82+
dimensions
83+
- Output: :math:`(N, *)`, same shape as the input
84+
85+
86+
Examples::
87+
88+
>>> m = Act['memswish']()
89+
>>> input = torch.randn(2)
90+
>>> output = m(input)
91+
"""
92+
93+
def forward(self, input: torch.Tensor):
94+
return SwishImplementation.apply(input)
95+
96+
4697
class Mish(nn.Module):
4798
r"""Applies the element-wise function:
4899

monai/networks/layers/factories.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,13 @@ def swish_factory():
256256
return Swish
257257

258258

259+
@Act.factory_function("memswish")
260+
def memswish_factory():
261+
from monai.networks.blocks.activation import MemoryEfficientSwish
262+
263+
return MemoryEfficientSwish
264+
265+
259266
@Act.factory_function("mish")
260267
def mish_factory():
261268
from monai.networks.blocks.activation import Mish

monai/networks/nets/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .classifier import Classifier, Critic, Discriminator
1616
from .densenet import DenseNet, DenseNet121, DenseNet169, DenseNet201, DenseNet264
1717
from .dynunet import DynUNet, DynUnet, Dynunet
18+
from .efficientnet import EfficientNet, EfficientNetBN, drop_connect, get_efficientnet_image_size
1819
from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
1920
from .generator import Generator
2021
from .highresnet import HighResBlock, HighResNet

0 commit comments

Comments
 (0)