Skip to content

PyTorch: fails on model with multiple return values. #1147

Open
@sei-jgwohlbier

Description

@sei-jgwohlbier

Prerequisites

Please make sure to check off these prerequisites before submitting a bug report.

  • Test that the bug appears on the current version of the master branch. Make sure to include the commit hash of the commit you checked out.
  • Check that the issue hasn't already been reported, by checking the currently open issues.
  • If there are steps to reproduce the problem, make sure to write them down below.
  • If relevant, please include the hls4ml project files, which were created directly before and/or after the bug.

Quick summary

hls4ml fails on a PyTorch model with multiple return values.

Details

hls4ml fails on the code below, which defines a model with two linear layers and returns the output of both layers.

Steps to Reproduce

Add what needs to be done to reproduce the bug. Add commented code examples and make sure to include the original model files / code, and the commit hash you are working on.

  1. Clone the hls4ml repository
  2. Checkout the master branch, with commit hash: [cc4fbf9]
  3. Run the conversion for the code below.
from pathlib import Path

import numpy as np
import os
import shutil
import torch
import torch.nn as nn
from torchinfo import summary

from hls4ml.converters import convert_from_pytorch_model
from hls4ml.utils.config import config_from_pytorch_model

# Directory containing this script; the hls4ml project directory is created beneath it.
test_root_path = Path(__file__).parent

class test(nn.Module):
    """Minimal two-output model: two linear layers fed by the same input.

    ``lin1`` maps ``n_in -> n1`` features and ``lin2`` maps ``n_in -> n2``
    features. Both layers have a bias term.
    """

    def __init__(self, n_in, n1, n2):
        super().__init__()
        # Layers are registered in the same order (lin1, then lin2).
        self.lin1 = nn.Linear(in_features=n_in, out_features=n1, bias=True)
        self.lin2 = nn.Linear(in_features=n_in, out_features=n2, bias=True)

    def forward(self, x):
        """Apply both layers to ``x`` and return ``(lin1(x), lin2(x))``."""
        return self.lin1(x), self.lin2(x)

# Reproduction driver: build the two-output model, write testbench data,
# convert the model with hls4ml, compare hls4ml predictions against PyTorch,
# and finally run the HLS build.
if __name__ == "__main__":

    # Problem size: X_input_shape is (batch, input features).
    n_batch = 16
    n_in = 16
    n1 = 32
    n2 = 64
    X_input_shape = (n_batch, n_in)

    model = test(n_in, n1, n2)
    io_type='io_stream'
    backend='Vitis'
    # The project directory name encodes backend and io_type; remove any
    # previous project so the run starts from a clean directory.
    output_dir = str(test_root_path / f'hls4mlprj_2lin_{backend}_{io_type}')
    if os.path.exists(output_dir):
        print("delete project dir")
        shutil.rmtree(output_dir)

    # Switch to inference mode and print a per-layer summary.
    model.eval()
    summary(model, input_size=X_input_shape)

    # Random input; the commented-out line below is an all-ones alternative.
    X_input = np.random.rand(*X_input_shape)
    #X_input = np.ones(X_input_shape)
    # Reference result: one numpy array per model output (lin1, lin2).
    with torch.no_grad():
        pytorch_prediction = [p.detach().numpy()
                              for p in model(torch.Tensor(X_input))]

    # No layout change is applied here: np.ascontiguousarray only returns a
    # C-contiguous array with the same shape and values.
    X_input_hls = np.ascontiguousarray(X_input)

    # Write testbench data. Files are opened in append mode, so a stale file
    # from an earlier run is deleted first.
    # NOTE(review): newline=" " makes np.savetxt separate values with spaces
    # and writes no line break between samples -- confirm the testbench
    # reader accepts this format.
    ipf = "./tb_input_features.dat"
    if os.path.isfile(ipf):
        os.remove(ipf)
    with open(ipf, "ab") as f:
        for x in X_input_hls:
            np.savetxt(f, x.flatten(), newline=" ")
    opf = "./tb_output_predictions.dat"
    if os.path.isfile(opf):
        os.remove(opf)
    with open(opf, "ab") as f:
        # For each sample, write the lin1 output row followed by the lin2
        # output row.
        for p0,p1 in zip(pytorch_prediction[0],
                         pytorch_prediction[1]):
            np.savetxt(f, p0.flatten(), newline=" ")
            np.savetxt(f, p1.flatten(), newline=" ")

    # NOTE: the first assignment is immediately overridden by the second, so
    # 'ap_fixed<32,12>' is the precision actually used.
    default_precision='ap_fixed<16,6>'
    default_precision='ap_fixed<32,12>'
    #default_precision='ap_fixed<64,24>'
    # input_shape drops the batch dimension: X_input_shape[-1:] == (n_in,).
    config = config_from_pytorch_model(model,
                                       input_shape=X_input_shape[-1:],
                                       backend=backend,
                                       default_precision=default_precision,
                                       default_reuse_factor=1,
                                       channels_last_conversion='internal',
                                       transpose_outputs=False)
    config['Model']['Strategy'] = 'Resource'
    print(config)
    print(output_dir)

    # Convert the PyTorch model to an hls4ml model, passing the testbench
    # files written above, then compile it.
    hls_model = convert_from_pytorch_model(
        model,
        output_dir=output_dir,
        input_data_tb=ipf,
        output_data_tb=opf,
        backend=backend,
        hls_config=config,
        io_type=io_type,
        part='xcvu9p-flga2104-2-e'
    )
    hls_model.compile()

    print("pytorch_prediction")
    print(pytorch_prediction)

    # Run the compiled hls4ml model on the same input; no reshape or
    # transpose is applied to the result.
    hls_prediction = hls_model.predict(X_input_hls)
    print("hls_prediction")
    print(hls_prediction)

    # Tolerances for the element-wise comparison below.
    rtol = 1.0e-2
    atol = 1.0e-2
    # Expect one prediction per model output (two here). In the reported run
    # this is the assertion that raises "length mismatch".
    assert len(pytorch_prediction) == len(hls_prediction), "length mismatch"

    # Compare each output sample by sample: first lin1, then lin2.
    for p0, h0 in zip(pytorch_prediction[0], hls_prediction[0]):
        np.testing.assert_allclose(p0,
                                   h0,
                                   rtol=rtol, atol=atol)
    for p1, h1 in zip(pytorch_prediction[1], hls_prediction[1]):
        np.testing.assert_allclose(p1,
                                   h1,
                                   rtol=rtol, atol=atol)
    # Run the HLS build with C simulation, synthesis, co-simulation and
    # validation enabled.
    hls_model.build(csim=True, synth=True, cosim=True, validation=True)

Expected behavior

Successful synthesis.

Actual behavior

python test_2lin.py 
2024-12-12 18:12:01.988121: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-12 18:12:02.040888: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-12-12 18:12:02.892484: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
delete project dir
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
test                                     [16, 32]                  --
├─Linear: 1-1                            [16, 32]                  544
├─Linear: 1-2                            [16, 64]                  1,088
==========================================================================================
Total params: 1,632
Trainable params: 1,632
Non-trainable params: 0
Total mult-adds (M): 0.03
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.01
Estimated Total Size (MB): 0.02
==========================================================================================
{'Model': {'Precision': {'default': 'ap_fixed<32,12>'}, 'ReuseFactor': 1, 'ChannelsLastConversion': 'internal', 'TransposeOutputs': False, 'Strategy': 'Resource', 'BramFactor': 1000000000, 'TraceOutput': False}, 'PytorchModel': test(
  (lin1): Linear(in_features=16, out_features=32, bias=True)
  (lin2): Linear(in_features=16, out_features=64, bias=True)
), 'InputShape': (16,)}
/home/hls4ml-user/work/ewstapp_research/isolate/NETWORK/hls4mlprj_2lin_Vitis_io_stream
Interpreting Model ...
Topology:
Layer name: lin1, layer type: Dense, input shape: [[None, 16]]
Layer name: lin2, layer type: Dense, input shape: [[None, 16]]
Creating HLS model
Writing HLS project
Done
pytorch_prediction
[array([[ 0.8891059 , -0.44395483, -0.05747134,  0.18016575,  0.04786346,
         0.5514014 , -0.16852657,  0.00964493,  0.3273672 ,  0.42060843,
         0.06706502,  0.15498346,  0.2457329 , -0.15184441,  0.09685186,
        -0.6596167 ,  0.2790345 ,  0.40409216, -0.23034032,  0.26463172,
        -0.46979874, -0.11001211,  0.35551935,  0.09460301, -0.10833421,
        -0.4492357 , -0.28191066, -0.26569235, -0.12289155, -0.5352483 ,
         0.5751673 , -0.2317074 ],
       [ 0.76985276, -0.1049726 ,  0.07535005,  0.10176656,  0.09320992,
         0.28592783, -0.04348151,  0.03626189,  0.00936881,  0.4154517 ,
         0.37312955,  0.14893359,  0.1893669 , -0.22227341, -0.08566827,
        -0.58724916, -0.25961903,  0.65872145, -0.34750587,  0.315466  ,
        -0.47754753,  0.00142413,  0.28266868,  0.29222852,  0.03858323,
        -0.2562538 , -0.35519725,  0.18092948, -0.07686479, -0.6124781 ,
         0.31275678, -0.27654955],
       [ 0.85511225, -0.24324456, -0.00846682,  0.22668763,  0.03866964,
         0.2381162 ,  0.0457862 ,  0.05140384,  0.04126164,  0.32769206,
         0.2474686 ,  0.20140621,  0.0427063 , -0.2780934 , -0.01511081,
        -0.58079964, -0.19579059,  0.5630405 , -0.37406617,  0.4501509 ,
        -0.47031593, -0.1698381 ,  0.46137342,  0.19732511, -0.02999109,
        -0.27819347, -0.33464026,  0.00155999, -0.07882459, -0.51667523,
         0.4007225 , -0.1982677 ],
       [ 1.0654068 , -0.30002496,  0.16547504,  0.2570746 ,  0.07158022,
         0.4347104 , -0.06412914,  0.16597775,  0.16384612,  0.4160096 ,
         0.08880451,  0.1005227 ,  0.1824699 , -0.19954087,  0.34508896,
        -0.53782004,  0.09642816,  0.8185116 , -0.34626994,  0.471716  ,
        -0.5092526 , -0.06822003,  0.3831837 , -0.01965211, -0.01387932,
        -0.37834692, -0.3783682 , -0.3562213 , -0.27375486, -0.6525427 ,
         0.6037679 ,  0.17533389],
       [ 0.52793086,  0.08541805,  0.03517117, -0.4244916 ,  0.10885802,
         0.43530622,  0.3118299 , -0.01598971,  0.3790553 ,  0.5554543 ,
         0.05826975,  0.11390461,  0.2410459 ,  0.0613706 ,  0.26139343,
        -0.27970743,  0.26997155,  0.46432167,  0.00322317, -0.15576953,
        -0.340056  , -0.08219175,  0.24044743, -0.10614166,  0.1167696 ,
        -0.38514078, -0.20315412, -0.13610272, -0.13506019, -0.39643157,
         0.43387794, -0.22893703],
       [ 0.83738965,  0.01773065,  0.01746632,  0.0049476 , -0.02727026,
         0.17095442,  0.26207945,  0.1697861 ,  0.34357035,  0.2642256 ,
         0.29654276,  0.2556939 ,  0.06309891, -0.10552   ,  0.08774575,
        -0.5153153 , -0.06944568,  0.31070724, -0.21419683,  0.21724322,
        -0.45854414, -0.04687934,  0.29160213,  0.29456928,  0.14869723,
        -0.2757703 , -0.3541801 ,  0.08705469, -0.09899832, -0.37215212,
         0.6330352 , -0.5796311 ],
       [ 0.58500886,  0.2640052 , -0.0189429 , -0.2794629 ,  0.13246663,
        -0.267674  ,  0.24941778, -0.04296389,  0.15840055,  0.01208394,
         0.1177678 ,  0.39987636,  0.08620736, -0.03397053,  0.12804201,
        -0.65928245, -0.05545972,  0.69912994, -0.16601579,  0.18794903,
        -0.7339839 ,  0.03901096,  0.30852503, -0.01032168,  0.08174405,
        -0.27913028, -0.23137385,  0.00499156,  0.09213072, -0.759608  ,
         0.91822934, -0.5346441 ],
       [ 0.56878453,  0.11198848,  0.05960959, -0.12241329,  0.12977597,
         0.08147588,  0.3533719 ,  0.16589719, -0.06445619,  0.4639053 ,
         0.462967  ,  0.2932239 ,  0.13533969, -0.2153621 ,  0.14075479,
        -0.5042372 , -0.26714593,  0.48706523, -0.33529457,  0.3466363 ,
        -0.34024402, -0.11915696,  0.3307217 ,  0.36106905,  0.18427725,
        -0.20332047, -0.35370904,  0.09610368, -0.09901851, -0.46560162,
         0.5428121 , -0.42115593],
       [ 0.6866636 , -0.11477496, -0.01831607, -0.02805769, -0.01344819,
         0.68515   , -0.04925771, -0.1972478 ,  0.31140062,  0.40612757,
         0.2530442 ,  0.21337444,  0.6395557 , -0.09065704,  0.19372801,
        -0.45185912,  0.50229234,  0.3983093 , -0.14483142,  0.07841846,
        -0.49485716, -0.02537266,  0.29264736,  0.1069174 ,  0.11361703,
        -0.20951061, -0.26409623, -0.32147884, -0.02064542, -0.511343  ,
         0.38575244,  0.04794483],
       [ 0.67789733, -0.30219573, -0.12434944,  0.11558396,  0.06762291,
         0.3116027 ,  0.15201744,  0.15036204,  0.06727348,  0.42700085,
         0.3871081 ,  0.3823516 ,  0.24762037, -0.17611447,  0.13901351,
        -0.53381497,  0.14353468,  0.49727145, -0.15057111,  0.32427734,
        -0.40415937, -0.0112884 ,  0.31515226,  0.16169925,  0.0040657 ,
        -0.19852826, -0.22190264, -0.18831491, -0.13794504, -0.5023341 ,
         0.69033474, -0.38809985],
       [ 0.81814086, -0.05858143,  0.07434263,  0.01335097,  0.01402169,
         0.5058249 ,  0.16288948, -0.10923052,  0.21130612,  0.52750933,
         0.12909204,  0.04708859,  0.51017034, -0.48467067,  0.236849  ,
        -0.53199136,  0.5741215 ,  0.66357696, -0.08221325,  0.04293117,
        -0.21072648,  0.13694671,  0.34113112,  0.00190126,  0.07781912,
        -0.01927111, -0.48293623, -0.401911  , -0.00399454, -0.6709269 ,
         0.76886785, -0.07476626],
       [ 0.6913032 , -0.1981028 , -0.08275409,  0.10008418, -0.07262716,
         0.36380088,  0.08553496, -0.16448833,  0.21087572,  0.53764087,
         0.19602291,  0.09081438,  0.2667737 , -0.33534533, -0.2282128 ,
        -0.5492    ,  0.21781437,  0.72637093, -0.14848016,  0.04423207,
        -0.2934043 ,  0.05480177,  0.37749898,  0.06654172,  0.00630023,
        -0.1037505 , -0.30250746, -0.19109204, -0.05297701, -0.60283387,
         0.3600675 , -0.24646895],
       [ 0.81141824, -0.33150065,  0.06518545, -0.02965383,  0.24818233,
         0.43532866,  0.10186243,  0.38129237,  0.31362757,  0.4576823 ,
         0.2271005 ,  0.2522714 ,  0.22506769,  0.23416433,  0.22466838,
        -0.17880167, -0.13405931,  0.50129116, -0.32637626,  0.40901405,
        -0.30316994, -0.19033791,  0.07678111,  0.04296248,  0.0158764 ,
        -0.37330478, -0.03181724, -0.10275158, -0.16623837, -0.33406097,
         0.5428152 , -0.32601902],
       [ 0.93497413, -0.09970862,  0.05688836,  0.05871909,  0.20382336,
         0.2938374 , -0.08205896,  0.13411754,  0.06540591,  0.35490224,
         0.16667452,  0.32861888,  0.15358543, -0.31738937,  0.5358051 ,
        -0.7712355 ,  0.17287332,  0.72606945, -0.15157634,  0.25406668,
        -0.7806479 , -0.15099436,  0.36042407, -0.00913737,  0.08302039,
        -0.4754683 , -0.47949797, -0.18461673, -0.27660465, -0.78692245,
         0.6459955 ,  0.0683116 ],
       [ 0.689624  ,  0.2917918 ,  0.23204178, -0.24007678,  0.15531424,
         0.06384554,  0.05830847, -0.05809493,  0.2642526 , -0.0223131 ,
         0.05240817,  0.23094033,  0.34774926, -0.06030922,  0.3801004 ,
        -0.45984933,  0.06846657,  0.45807868, -0.28368205,  0.19050267,
        -0.60009193,  0.00462461,  0.19878045,  0.07578797,  0.16144395,
        -0.32184047, -0.33067778, -0.05224532,  0.04741496, -0.5540669 ,
         0.75985986, -0.11818236],
       [ 0.42422694, -0.13985819,  0.15179643, -0.0994494 ,  0.05480643,
         0.4122148 ,  0.03583536, -0.03997545,  0.0027165 ,  0.5133945 ,
         0.23952836,  0.02225108,  0.21865284,  0.06876859,  0.21143112,
        -0.6492269 ,  0.18834093,  0.43246025, -0.3369369 ,  0.12497532,
        -0.4983435 , -0.05300211,  0.35259238,  0.36499974, -0.01017824,
        -0.51138484, -0.400999  , -0.19466041, -0.20390618, -0.6352968 ,
         0.6848689 , -0.19771972]], dtype=float32), array([[-0.64182013,  0.40184742, -0.07485253, ...,  0.11808984,
        -0.22370778,  0.11206529],
       [-0.5201124 ,  0.18107897, -0.02317163, ..., -0.2986831 ,
         0.10982952,  0.18085258],
       [-0.50899744,  0.2951257 , -0.13951811, ..., -0.02748931,
        -0.06415796,  0.18319744],
       ...,
       [-0.63268775,  0.62129   , -0.16334109, ..., -0.16879866,
        -0.31751645,  0.15808316],
       [-0.71309054,  0.2369004 ,  0.10258856, ..., -0.11202749,
        -0.46535045,  0.2605038 ],
       [-0.71718186,  0.42549616,  0.10741499, ..., -0.22839454,
        -0.1128529 ,  0.39642552]], dtype=float32)]
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
hls_prediction
[[-0.64183044  0.40183544 -0.07486248 ...  0.11807919 -0.22371769
   0.11205196]
 [-0.52012253  0.18106556 -0.02318382 ... -0.29869556  0.10981655
   0.18084049]
 [-0.50900841  0.29511356 -0.13953018 ... -0.02750206 -0.06417084
   0.18318558]
 ...
 [-0.63269997  0.62127781 -0.16335392 ... -0.1688118  -0.31753063
   0.15807152]
 [-0.71309853  0.23688602  0.10257721 ... -0.11204052 -0.46536255
   0.26049042]
 [-0.7171917   0.42548656  0.10740471 ... -0.22840595 -0.11286354
   0.3964119 ]]
Traceback (most recent call last):
  File "/home/hls4ml-user/work/ewstapp_research/isolate/NETWORK/test_2lin.py", line 107, in <module>
    assert len(pytorch_prediction) == len(hls_prediction), "length mismatch"
AssertionError: length mismatch

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions