Skip to content

Error in config_from_pytorch_model when using two Conv2d layers after Upsample #1184

Closed
@Dyoxyz

Description

@Dyoxyz

Prerequisites

Please make sure to check off these prerequisites before submitting a bug report.

  • Test that the bug appears on the current version of the master branch. Make sure to include the commit hash of the commit you checked out.
  • Check that the issue hasn't already been reported, by checking the currently open issues.
  • If there are steps to reproduce the problem, make sure to write them down below.
  • If relevant, please include the hls4ml project files, which were created directly before and/or after the bug.

Quick summary

hls4ml's `config_from_pytorch_model` fails on the second Conv2d layer that follows an Upsample layer.

Details

hls4ml's `config_from_pytorch_model` fails on the code below.

Steps to Reproduce

Add what needs to be done to reproduce the bug. Add commented code examples and make sure to include the original model files / code, and the commit hash you are working on.

  1. Clone the hls4ml repository
  2. Check out the master branch at commit hash [1ad1ad9]
  3. Run the code below:
import torch.nn as nn
import torch.nn.functional as F

def init_weights(m):
    """Initialize the parameters of a single module in place.

    Meant to be applied to every submodule, e.g. ``for m in self.modules()``.

    * ``Conv2d`` / ``ConvTranspose2d``: Kaiming-normal weights (``fan_out``
      mode, ``leaky_relu`` nonlinearity) and a zero bias when a bias exists.
    * ``BatchNorm2d``: weight set to 1 and bias set to 0, when present.
    * Any other module is left untouched.

    Args:
        m: The ``nn.Module`` to initialize.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="leaky_relu")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has weight/bias set to None; skip them
        # instead of failing inside nn.init.constant_.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

class test_model(nn.Module):
    """Minimal model that reproduces the config_from_pytorch_model failure.

    Layer sequence: conv1 -> leaky_relu -> upsample1 -> leaky_relu
    -> conv2 -> leaky_relu -> conv3 -> leaky_relu. Per this report, hls4ml
    fails on the second convolution after the upsample (conv3).
    """

    def __init__(self):
        super(test_model, self).__init__()

        # 3x3 kernels with stride 1 and padding 1 keep height/width unchanged;
        # channels go 1 -> 2 -> 2 -> 1.
        self.conv1 = nn.Conv2d(1, 2, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(2, 1, kernel_size=3, stride=1, padding=1)

        # Doubles the spatial size (height and width).
        self.upsample1 = nn.Upsample(scale_factor=2)

        # Apply init_weights to every submodule (including self).
        for m in self.modules():
            init_weights(m)

    def forward(self, x):
        # x: presumably (N, 1, H, W), since conv1 expects 1 input channel.
        # Keep the exact call sequence below: it is what the bug depends on.
        x1 = F.leaky_relu(self.conv1(x))

        # leaky_relu is applied again after the upsample, as in the report.
        x1u = F.leaky_relu(self.upsample1(x1))

        # First conv2d after the upsample (works on its own).
        x2 = F.leaky_relu(self.conv2(x1u))

        # Second conv2d after the upsample; adding this layer triggers the error.
        x3 = F.leaky_relu(self.conv3(x2))

        return x3

# Instantiate the reproducer model (weights are initialized in its constructor).
model = test_model()

import hls4ml
from hls4ml.utils.config import config_from_pytorch_model

# Per the traceback in this report, this call raises
# "TypeError: cannot unpack non-iterable NoneType object" from
# parse_conv2d_layer in hls4ml/converters/pytorch/convolution.py.
# NOTE(review): the leading None in input_shape is presumably the batch
# dimension — confirm against the hls4ml documentation.
config = config_from_pytorch_model(model, input_shape=(None, 1, 128, 128), default_precision='ap_fixed<16,6>', default_reuse_factor=1, channels_last_conversion='internal', transpose_outputs=False)
# Not reached in the failing case; sets the model-level 'Strategy' entry.
config['Model']['Strategy'] = 'Resource'

Expected behavior

Successful creation of the config from the model.

Actual behavior

Traceback (most recent call last):
  File "/home/user/hls4ml/test1.py", line 43, in <module>
    config = config_from_pytorch_model(model, input_shape=(None, 1, 128, 128), default_precision='ap_fixed<16,6>', default_reuse_factor=1, channels_last_conversion='internal', transpose_outputs=False)
  File "/home/user/hls4ml/hls4ml/utils/config.py", line 371, in config_from_pytorch_model
    ) = parse_pytorch_model(config, verbose=False)
  File "/home/user/hls4ml/hls4ml/converters/pytorch_to_hls.py", line 246, in parse_pytorch_model
    layer, output_shape = layer_handlers[pytorch_class](
  File "/home/user/hls4ml/hls4ml/converters/pytorch/convolution.py", line 68, in parse_conv2d_layer
    (layer['in_height'], layer['in_width'], layer['n_chan']) = parse_data_format(
TypeError: cannot unpack non-iterable NoneType object

Optional

Additional context

Creating the config from the model works when only one Conv2d layer follows the Upsample layer (see the code below).

import torch.nn as nn
import torch.nn.functional as F

def init_weights(m):
    """Initialize the parameters of a single module in place.

    Meant to be applied to every submodule, e.g. ``for m in self.modules()``.

    * ``Conv2d`` / ``ConvTranspose2d``: Kaiming-normal weights (``fan_out``
      mode, ``leaky_relu`` nonlinearity) and a zero bias when a bias exists.
    * ``BatchNorm2d``: weight set to 1 and bias set to 0, when present.
    * Any other module is left untouched.

    Args:
        m: The ``nn.Module`` to initialize.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="leaky_relu")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has weight/bias set to None; skip them
        # instead of failing inside nn.init.constant_.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

class test_model(nn.Module):
    """Working variant: only one convolution follows the upsample.

    Layer sequence: conv1 -> leaky_relu -> upsample1 -> leaky_relu
    -> conv2 -> leaky_relu. Per this report, config_from_pytorch_model
    succeeds on this model.
    """

    def __init__(self):
        super(test_model, self).__init__()

        # 3x3 kernels with stride 1 and padding 1 keep height/width unchanged.
        self.conv1 = nn.Conv2d(1, 2, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1)
        # conv3 is constructed (and initialized) but never used by forward().
        self.conv3 = nn.Conv2d(2, 1, kernel_size=3, stride=1, padding=1)

        # Doubles the spatial size (height and width).
        self.upsample1 = nn.Upsample(scale_factor=2)

        # Apply init_weights to every submodule (including self).
        for m in self.modules():
            init_weights(m)

    def forward(self, x):
        # x: presumably (N, 1, H, W), since conv1 expects 1 input channel.
        x1 = F.leaky_relu(self.conv1(x))

        # leaky_relu is applied again after the upsample, as in the failing model.
        x1u = F.leaky_relu(self.upsample1(x1))

        # The only conv2d after the upsample.
        x2 = F.leaky_relu(self.conv2(x1u))

        return x2

# Instantiate the working-variant model (weights are initialized in its constructor).
model = test_model()

import hls4ml
from hls4ml.utils.config import config_from_pytorch_model

# Same arguments as the failing case; per this report, this call succeeds.
# NOTE(review): the leading None in input_shape is presumably the batch
# dimension — confirm against the hls4ml documentation.
config = config_from_pytorch_model(model, input_shape=(None, 1, 128, 128), default_precision='ap_fixed<16,6>', default_reuse_factor=1, channels_last_conversion='internal', transpose_outputs=False)
# Sets the model-level 'Strategy' entry of the generated config.
config['Model']['Strategy'] = 'Resource'

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions