Open
Description
Prerequisites
Please make sure to check off these prerequisites before submitting a bug report.
- Test that the bug appears on the current version of the master branch. Make sure to include the commit hash of the commit you checked out.
- Check that the issue hasn't already been reported, by checking the currently open issues.
- If there are steps to reproduce the problem, make sure to write them down below.
- If relevant, please include the hls4ml project files, which were created directly before and/or after the bug.
Quick summary
hls4ml fails on a PyTorch model with multiple return values.
Details
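hls4ml fails on the code below, whose model has two linear layers and returns the outputs of both. The generated HLS model exposes only one output: its values match the PyTorch `lin2` output. The stream `layer2_out` is reported as containing leftover data, and `hls_model.predict` returns a single array instead of one array per output.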
Steps to Reproduce
Add what needs to be done to reproduce the bug. Add commented code examples and make sure to include the original model files / code, and the commit hash you are working on.
- Clone the hls4ml repository
- Checkout the master branch, with commit hash: [cc4fbf9]
- Run the conversion script below (`test_2lin.py`).
from pathlib import Path
import numpy as np
import os
import shutil
import torch
import torch.nn as nn
from torchinfo import summary
from hls4ml.converters import convert_from_pytorch_model
from hls4ml.utils.config import config_from_pytorch_model
# Directory containing this script; the hls4ml project directory is created under it.
test_root_path = Path(__file__).parent
class test(nn.Module):
    """Two independent linear heads applied to the same input.

    ``forward`` returns the tuple ``(lin1(x), lin2(x))``, so the model has
    two outputs of widths ``n1`` and ``n2``. This multi-output case is the
    one the hls4ml conversion in this report does not handle.

    Args:
        n_in: Number of input features.
        n1: Output width of the first linear layer.
        n2: Output width of the second linear layer.
    """

    def __init__(self, n_in, n1, n2):
        super().__init__()
        self.lin1 = nn.Linear(n_in, n1, bias=True)
        self.lin2 = nn.Linear(n_in, n2, bias=True)

    def forward(self, x):
        # Both heads read the same input; neither feeds into the other.
        y = self.lin1(x)
        z = self.lin2(x)
        return y, z
def write_tb_data(path, samples):
    """Write testbench data with exactly one sample per line.

    Each sample is flattened to a single row and written as one
    space-separated line. A 1-D ``np.savetxt(..., newline=" ")`` call would
    never emit a line break and would put every sample on the same line.
    Any existing file at *path* is overwritten.

    NOTE(review): one-sample-per-line is the layout the generated hls4ml
    C++ testbench is expected to read; confirm against the generated
    ``*_test.cpp`` for the backend in use.

    Args:
        path: Destination file path.
        samples: Iterable of array-likes; each one becomes one line.
    """
    with open(path, "w") as f:
        for sample in samples:
            np.savetxt(f, np.reshape(sample, (1, -1)))


def as_output_list(prediction):
    """Return model predictions as a list holding one array per output.

    A bare ndarray (as returned by ``hls_model.predict`` when it produces
    a single output) is wrapped in a one-element list, so ``len`` counts
    outputs rather than batch rows. Lists and tuples are converted to lists.
    """
    if isinstance(prediction, np.ndarray):
        return [prediction]
    return list(prediction)


if __name__ == "__main__":
    n_batch = 16
    n_in = 16
    n1 = 32
    n2 = 64
    X_input_shape = (n_batch, n_in)
    model = test(n_in, n1, n2)

    io_type = 'io_stream'
    backend = 'Vitis'
    output_dir = str(test_root_path / f'hls4mlprj_2lin_{backend}_{io_type}')
    if os.path.exists(output_dir):
        print("delete project dir")
        shutil.rmtree(output_dir)

    model.eval()
    summary(model, input_size=X_input_shape)

    # Generate float32 data so PyTorch and hls4ml see identical values;
    # torch.Tensor() on float64 input would silently downcast only one side.
    X_input = np.random.rand(*X_input_shape).astype(np.float32)
    with torch.no_grad():
        pytorch_prediction = [p.numpy() for p in model(torch.from_numpy(X_input))]

    # Input is 2-D (batch, features), so there is no channel axis to move.
    X_input_hls = np.ascontiguousarray(X_input)

    # Testbench data, one sample per line. For the expected outputs, all
    # outputs of a sample are concatenated in model output order
    # (lin1, then lin2) on that sample's line.
    ipf = "./tb_input_features.dat"
    write_tb_data(ipf, X_input_hls)
    opf = "./tb_output_predictions.dat"
    write_tb_data(
        opf,
        (np.concatenate([out.ravel() for out in per_sample])
         for per_sample in zip(*pytorch_prediction)),
    )

    # Other precisions tried: 'ap_fixed<16,6>', 'ap_fixed<64,24>'.
    default_precision = 'ap_fixed<32,12>'
    config = config_from_pytorch_model(
        model,
        input_shape=X_input_shape[-1:],
        backend=backend,
        default_precision=default_precision,
        default_reuse_factor=1,
        channels_last_conversion='internal',
        transpose_outputs=False,
    )
    config['Model']['Strategy'] = 'Resource'
    print(config)
    print(output_dir)

    hls_model = convert_from_pytorch_model(
        model,
        output_dir=output_dir,
        input_data_tb=ipf,
        output_data_tb=opf,
        backend=backend,
        hls_config=config,
        io_type=io_type,
        part='xcvu9p-flga2104-2-e',
    )
    hls_model.compile()

    print("pytorch_prediction")
    print(pytorch_prediction)
    hls_prediction = as_output_list(hls_model.predict(X_input_hls))
    print("hls_prediction")
    print(hls_prediction)

    rtol = 1.0e-2
    atol = 1.0e-2
    assert len(pytorch_prediction) == len(hls_prediction), (
        f"length mismatch: PyTorch model has {len(pytorch_prediction)} "
        f"outputs, hls4ml model has {len(hls_prediction)}"
    )
    # Compare whole arrays: a shape mismatch must fail rather than be
    # hidden by zip() truncating to the shorter sequence.
    for idx, (p, h) in enumerate(zip(pytorch_prediction, hls_prediction)):
        np.testing.assert_allclose(p, h, rtol=rtol, atol=atol, err_msg=f"output {idx}")

    # synthesize
    hls_model.build(csim=True, synth=True, cosim=True, validation=True)
Expected behavior
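Successful conversion, with both model outputs exposed by the HLS model and matching the PyTorch predictions, followed by successful synthesis.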
Actual behavior
python test_2lin.py
2024-12-12 18:12:01.988121: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-12 18:12:02.040888: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-12-12 18:12:02.892484: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
delete project dir
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
test [16, 32] --
├─Linear: 1-1 [16, 32] 544
├─Linear: 1-2 [16, 64] 1,088
==========================================================================================
Total params: 1,632
Trainable params: 1,632
Non-trainable params: 0
Total mult-adds (M): 0.03
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.01
Estimated Total Size (MB): 0.02
==========================================================================================
{'Model': {'Precision': {'default': 'ap_fixed<32,12>'}, 'ReuseFactor': 1, 'ChannelsLastConversion': 'internal', 'TransposeOutputs': False, 'Strategy': 'Resource', 'BramFactor': 1000000000, 'TraceOutput': False}, 'PytorchModel': test(
(lin1): Linear(in_features=16, out_features=32, bias=True)
(lin2): Linear(in_features=16, out_features=64, bias=True)
), 'InputShape': (16,)}
/home/hls4ml-user/work/ewstapp_research/isolate/NETWORK/hls4mlprj_2lin_Vitis_io_stream
Interpreting Model ...
Topology:
Layer name: lin1, layer type: Dense, input shape: [[None, 16]]
Layer name: lin2, layer type: Dense, input shape: [[None, 16]]
Creating HLS model
Writing HLS project
Done
pytorch_prediction
[array([[ 0.8891059 , -0.44395483, -0.05747134, 0.18016575, 0.04786346,
0.5514014 , -0.16852657, 0.00964493, 0.3273672 , 0.42060843,
0.06706502, 0.15498346, 0.2457329 , -0.15184441, 0.09685186,
-0.6596167 , 0.2790345 , 0.40409216, -0.23034032, 0.26463172,
-0.46979874, -0.11001211, 0.35551935, 0.09460301, -0.10833421,
-0.4492357 , -0.28191066, -0.26569235, -0.12289155, -0.5352483 ,
0.5751673 , -0.2317074 ],
[ 0.76985276, -0.1049726 , 0.07535005, 0.10176656, 0.09320992,
0.28592783, -0.04348151, 0.03626189, 0.00936881, 0.4154517 ,
0.37312955, 0.14893359, 0.1893669 , -0.22227341, -0.08566827,
-0.58724916, -0.25961903, 0.65872145, -0.34750587, 0.315466 ,
-0.47754753, 0.00142413, 0.28266868, 0.29222852, 0.03858323,
-0.2562538 , -0.35519725, 0.18092948, -0.07686479, -0.6124781 ,
0.31275678, -0.27654955],
[ 0.85511225, -0.24324456, -0.00846682, 0.22668763, 0.03866964,
0.2381162 , 0.0457862 , 0.05140384, 0.04126164, 0.32769206,
0.2474686 , 0.20140621, 0.0427063 , -0.2780934 , -0.01511081,
-0.58079964, -0.19579059, 0.5630405 , -0.37406617, 0.4501509 ,
-0.47031593, -0.1698381 , 0.46137342, 0.19732511, -0.02999109,
-0.27819347, -0.33464026, 0.00155999, -0.07882459, -0.51667523,
0.4007225 , -0.1982677 ],
[ 1.0654068 , -0.30002496, 0.16547504, 0.2570746 , 0.07158022,
0.4347104 , -0.06412914, 0.16597775, 0.16384612, 0.4160096 ,
0.08880451, 0.1005227 , 0.1824699 , -0.19954087, 0.34508896,
-0.53782004, 0.09642816, 0.8185116 , -0.34626994, 0.471716 ,
-0.5092526 , -0.06822003, 0.3831837 , -0.01965211, -0.01387932,
-0.37834692, -0.3783682 , -0.3562213 , -0.27375486, -0.6525427 ,
0.6037679 , 0.17533389],
[ 0.52793086, 0.08541805, 0.03517117, -0.4244916 , 0.10885802,
0.43530622, 0.3118299 , -0.01598971, 0.3790553 , 0.5554543 ,
0.05826975, 0.11390461, 0.2410459 , 0.0613706 , 0.26139343,
-0.27970743, 0.26997155, 0.46432167, 0.00322317, -0.15576953,
-0.340056 , -0.08219175, 0.24044743, -0.10614166, 0.1167696 ,
-0.38514078, -0.20315412, -0.13610272, -0.13506019, -0.39643157,
0.43387794, -0.22893703],
[ 0.83738965, 0.01773065, 0.01746632, 0.0049476 , -0.02727026,
0.17095442, 0.26207945, 0.1697861 , 0.34357035, 0.2642256 ,
0.29654276, 0.2556939 , 0.06309891, -0.10552 , 0.08774575,
-0.5153153 , -0.06944568, 0.31070724, -0.21419683, 0.21724322,
-0.45854414, -0.04687934, 0.29160213, 0.29456928, 0.14869723,
-0.2757703 , -0.3541801 , 0.08705469, -0.09899832, -0.37215212,
0.6330352 , -0.5796311 ],
[ 0.58500886, 0.2640052 , -0.0189429 , -0.2794629 , 0.13246663,
-0.267674 , 0.24941778, -0.04296389, 0.15840055, 0.01208394,
0.1177678 , 0.39987636, 0.08620736, -0.03397053, 0.12804201,
-0.65928245, -0.05545972, 0.69912994, -0.16601579, 0.18794903,
-0.7339839 , 0.03901096, 0.30852503, -0.01032168, 0.08174405,
-0.27913028, -0.23137385, 0.00499156, 0.09213072, -0.759608 ,
0.91822934, -0.5346441 ],
[ 0.56878453, 0.11198848, 0.05960959, -0.12241329, 0.12977597,
0.08147588, 0.3533719 , 0.16589719, -0.06445619, 0.4639053 ,
0.462967 , 0.2932239 , 0.13533969, -0.2153621 , 0.14075479,
-0.5042372 , -0.26714593, 0.48706523, -0.33529457, 0.3466363 ,
-0.34024402, -0.11915696, 0.3307217 , 0.36106905, 0.18427725,
-0.20332047, -0.35370904, 0.09610368, -0.09901851, -0.46560162,
0.5428121 , -0.42115593],
[ 0.6866636 , -0.11477496, -0.01831607, -0.02805769, -0.01344819,
0.68515 , -0.04925771, -0.1972478 , 0.31140062, 0.40612757,
0.2530442 , 0.21337444, 0.6395557 , -0.09065704, 0.19372801,
-0.45185912, 0.50229234, 0.3983093 , -0.14483142, 0.07841846,
-0.49485716, -0.02537266, 0.29264736, 0.1069174 , 0.11361703,
-0.20951061, -0.26409623, -0.32147884, -0.02064542, -0.511343 ,
0.38575244, 0.04794483],
[ 0.67789733, -0.30219573, -0.12434944, 0.11558396, 0.06762291,
0.3116027 , 0.15201744, 0.15036204, 0.06727348, 0.42700085,
0.3871081 , 0.3823516 , 0.24762037, -0.17611447, 0.13901351,
-0.53381497, 0.14353468, 0.49727145, -0.15057111, 0.32427734,
-0.40415937, -0.0112884 , 0.31515226, 0.16169925, 0.0040657 ,
-0.19852826, -0.22190264, -0.18831491, -0.13794504, -0.5023341 ,
0.69033474, -0.38809985],
[ 0.81814086, -0.05858143, 0.07434263, 0.01335097, 0.01402169,
0.5058249 , 0.16288948, -0.10923052, 0.21130612, 0.52750933,
0.12909204, 0.04708859, 0.51017034, -0.48467067, 0.236849 ,
-0.53199136, 0.5741215 , 0.66357696, -0.08221325, 0.04293117,
-0.21072648, 0.13694671, 0.34113112, 0.00190126, 0.07781912,
-0.01927111, -0.48293623, -0.401911 , -0.00399454, -0.6709269 ,
0.76886785, -0.07476626],
[ 0.6913032 , -0.1981028 , -0.08275409, 0.10008418, -0.07262716,
0.36380088, 0.08553496, -0.16448833, 0.21087572, 0.53764087,
0.19602291, 0.09081438, 0.2667737 , -0.33534533, -0.2282128 ,
-0.5492 , 0.21781437, 0.72637093, -0.14848016, 0.04423207,
-0.2934043 , 0.05480177, 0.37749898, 0.06654172, 0.00630023,
-0.1037505 , -0.30250746, -0.19109204, -0.05297701, -0.60283387,
0.3600675 , -0.24646895],
[ 0.81141824, -0.33150065, 0.06518545, -0.02965383, 0.24818233,
0.43532866, 0.10186243, 0.38129237, 0.31362757, 0.4576823 ,
0.2271005 , 0.2522714 , 0.22506769, 0.23416433, 0.22466838,
-0.17880167, -0.13405931, 0.50129116, -0.32637626, 0.40901405,
-0.30316994, -0.19033791, 0.07678111, 0.04296248, 0.0158764 ,
-0.37330478, -0.03181724, -0.10275158, -0.16623837, -0.33406097,
0.5428152 , -0.32601902],
[ 0.93497413, -0.09970862, 0.05688836, 0.05871909, 0.20382336,
0.2938374 , -0.08205896, 0.13411754, 0.06540591, 0.35490224,
0.16667452, 0.32861888, 0.15358543, -0.31738937, 0.5358051 ,
-0.7712355 , 0.17287332, 0.72606945, -0.15157634, 0.25406668,
-0.7806479 , -0.15099436, 0.36042407, -0.00913737, 0.08302039,
-0.4754683 , -0.47949797, -0.18461673, -0.27660465, -0.78692245,
0.6459955 , 0.0683116 ],
[ 0.689624 , 0.2917918 , 0.23204178, -0.24007678, 0.15531424,
0.06384554, 0.05830847, -0.05809493, 0.2642526 , -0.0223131 ,
0.05240817, 0.23094033, 0.34774926, -0.06030922, 0.3801004 ,
-0.45984933, 0.06846657, 0.45807868, -0.28368205, 0.19050267,
-0.60009193, 0.00462461, 0.19878045, 0.07578797, 0.16144395,
-0.32184047, -0.33067778, -0.05224532, 0.04741496, -0.5540669 ,
0.75985986, -0.11818236],
[ 0.42422694, -0.13985819, 0.15179643, -0.0994494 , 0.05480643,
0.4122148 , 0.03583536, -0.03997545, 0.0027165 , 0.5133945 ,
0.23952836, 0.02225108, 0.21865284, 0.06876859, 0.21143112,
-0.6492269 , 0.18834093, 0.43246025, -0.3369369 , 0.12497532,
-0.4983435 , -0.05300211, 0.35259238, 0.36499974, -0.01017824,
-0.51138484, -0.400999 , -0.19466041, -0.20390618, -0.6352968 ,
0.6848689 , -0.19771972]], dtype=float32), array([[-0.64182013, 0.40184742, -0.07485253, ..., 0.11808984,
-0.22370778, 0.11206529],
[-0.5201124 , 0.18107897, -0.02317163, ..., -0.2986831 ,
0.10982952, 0.18085258],
[-0.50899744, 0.2951257 , -0.13951811, ..., -0.02748931,
-0.06415796, 0.18319744],
...,
[-0.63268775, 0.62129 , -0.16334109, ..., -0.16879866,
-0.31751645, 0.15808316],
[-0.71309054, 0.2369004 , 0.10258856, ..., -0.11202749,
-0.46535045, 0.2605038 ],
[-0.71718186, 0.42549616, 0.10741499, ..., -0.22839454,
-0.1128529 , 0.39642552]], dtype=float32)]
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
WARNING: Hls::stream 'layer2_out' contains leftover data, which may result in RTL simulation hanging.
hls_prediction
[[-0.64183044 0.40183544 -0.07486248 ... 0.11807919 -0.22371769
0.11205196]
[-0.52012253 0.18106556 -0.02318382 ... -0.29869556 0.10981655
0.18084049]
[-0.50900841 0.29511356 -0.13953018 ... -0.02750206 -0.06417084
0.18318558]
...
[-0.63269997 0.62127781 -0.16335392 ... -0.1688118 -0.31753063
0.15807152]
[-0.71309853 0.23688602 0.10257721 ... -0.11204052 -0.46536255
0.26049042]
[-0.7171917 0.42548656 0.10740471 ... -0.22840595 -0.11286354
0.3964119 ]]
Traceback (most recent call last):
File "/home/hls4ml-user/work/ewstapp_research/isolate/NETWORK/test_2lin.py", line 107, in <module>
assert len(pytorch_prediction) == len(hls_prediction), "length mismatch"
AssertionError: length mismatch