diff --git a/timm/models/vgg.py b/timm/models/vgg.py index c096df23fb..a4cfbffdff 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -38,8 +38,8 @@ def __init__( kernel_size=7, mlp_ratio=1.0, drop_rate: float = 0.2, - act_layer: Optional[Type[nn.Module]] = None, - conv_layer: Optional[Type[nn.Module]] = None, + act_layer: Type[nn.Module] = nn.ReLU, + conv_layer: Type[nn.Module] = nn.Conv2d, ): super(ConvMlp, self).__init__() self.input_kernel_size = kernel_size