Skip to content

Adding EfficientNet B0–B7 support #1938

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Apr 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ Blocks
.. autoclass:: Swish
:members:

`MemoryEfficientSwish`
~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: MemoryEfficientSwish
:members:

`Mish`
~~~~~~
.. autoclass:: Mish
Expand Down Expand Up @@ -292,6 +297,11 @@ Nets
.. autoclass:: DenseNet
:members:

`EfficientNet`
~~~~~~~~~~~~~~
.. autoclass:: EfficientNet
:members:

`SegResNet`
~~~~~~~~~~~
.. autoclass:: SegResNet
Expand Down
2 changes: 1 addition & 1 deletion monai/networks/blocks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

from .acti_norm import ADN
from .activation import Mish, Swish
from .activation import MemoryEfficientSwish, Mish, Swish
from .aspp import SimpleASPP
from .convolutions import Convolution, ResidualUnit
from .crf import CRF
Expand Down
53 changes: 52 additions & 1 deletion monai/networks/blocks/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Swish(nn.Module):
r"""Applies the element-wise function:

.. math::
\text{Swish}(x) = x * \text{Sigmoid}(\alpha * x) for constant value alpha.
\text{Swish}(x) = x * \text{Sigmoid}(\alpha * x) ~~~~\text{for constant value}~ \alpha.

Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941.

Expand All @@ -43,6 +43,57 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return input * torch.sigmoid(self.alpha * input)


class SwishImplementation(torch.autograd.Function):
    r"""Memory efficient implementation for training
    Follows recommendation from:
    https://github.com/lukemelas/EfficientNet-PyTorch/issues/18#issuecomment-511677853

    Results in ~ 30% memory saving during training as compared to Swish()
    """

    @staticmethod
    def forward(ctx, input):
        # Keep only the raw input for backward; sigmoid is recomputed there,
        # which is what saves the intermediate-activation memory.
        ctx.save_for_backward(input)
        return input * torch.sigmoid(input)

    @staticmethod
    def backward(ctx, grad_output):
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        (saved,) = ctx.saved_tensors
        sig = torch.sigmoid(saved)
        return grad_output * (sig * (1 + saved * (1 - sig)))


class MemoryEfficientSwish(nn.Module):
    r"""Applies the element-wise function:

    .. math::
        \text{Swish}(x) = x * \text{Sigmoid}(\alpha * x) ~~~~\text{for constant value}~ \alpha=1.

    Memory efficient implementation for training following recommendation from:
    https://github.com/lukemelas/EfficientNet-PyTorch/issues/18#issuecomment-511677853

    Results in ~ 30% memory saving during training as compared to Swish()

    Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941.

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    Examples::

        >>> m = Act['memswish']()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Delegate to the custom autograd Function, which recomputes the
        # sigmoid in backward instead of caching intermediate activations.
        return SwishImplementation.apply(input)


class Mish(nn.Module):
r"""Applies the element-wise function:

Expand Down
7 changes: 7 additions & 0 deletions monai/networks/layers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,13 @@ def swish_factory():
return Swish


@Act.factory_function("memswish")
def memswish_factory():
    """Register and return the ``MemoryEfficientSwish`` activation class under ``Act['memswish']``."""
    # Deferred (function-scope) import, matching the other activation factories
    # in this module, to avoid a circular import at module load time.
    from monai.networks.blocks.activation import MemoryEfficientSwish

    return MemoryEfficientSwish


@Act.factory_function("mish")
def mish_factory():
from monai.networks.blocks.activation import Mish
Expand Down
1 change: 1 addition & 0 deletions monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .classifier import Classifier, Critic, Discriminator
from .densenet import DenseNet, DenseNet121, DenseNet169, DenseNet201, DenseNet264
from .dynunet import DynUNet, DynUnet, Dynunet
from .efficientnet import EfficientNet, EfficientNetBN, drop_connect, get_efficientnet_image_size
from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
Expand Down
Loading