Skip to content

Commit 3b7e595

Browse files
Support multiple outputs in pytorch parser (#1151)
* support multiple outputs in pytorch parser * pre-commit * [pre-commit.ci] auto fixes from pre-commit hooks * fix missing input names in squeeze layers --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 90ada94 commit 3b7e595

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

hls4ml/converters/pytorch/reshape.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def parse_squeeze_layer(operation, layer_name, input_names, input_shapes, node,
3838
layer = {}
3939
layer['class_name'] = 'Reshape'
4040
layer['name'] = layer_name
41+
layer['inputs'] = input_names
4142

4243
if len(node.args) > 1 or len(node.kwargs) > 0: # 'dim' argument is specified
4344
output_shape = [i for i in input_shapes[0]]

hls4ml/converters/pytorch_to_hls.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def parse_pytorch_model(config, verbose=True):
151151
inputs_map = {}
152152

153153
input_layers = []
154+
output_layers = []
154155

155156
# Output shape tracking
156157
output_shapes = {}
@@ -399,12 +400,23 @@ def parse_pytorch_model(config, verbose=True):
399400
if len(input_layers) == 0:
400401
input_layers = None
401402

402-
return layer_list, input_layers
403+
for layer in layer_list:
404+
if layer['class_name'] == 'InputLayer':
405+
continue
406+
is_input = False
407+
for lay in layer_list:
408+
if 'inputs' not in lay.keys():
409+
continue
410+
if layer['name'] in lay['inputs']:
411+
is_input = True
412+
if not is_input:
413+
output_layers.append(layer['name'])
414+
return layer_list, input_layers, output_layers
403415

404416

405417
@requires('_torch')
406418
def pytorch_to_hls(config):
407-
layer_list, input_layers = parse_pytorch_model(config)
419+
layer_list, input_layers, output_layers = parse_pytorch_model(config)
408420
print('Creating HLS model')
409-
hls_model = ModelGraph(config, layer_list, inputs=input_layers)
421+
hls_model = ModelGraph(config, layer_list, inputs=input_layers, outputs=output_layers)
410422
return hls_model

hls4ml/utils/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ def config_from_pytorch_model(
368368
(
369369
layer_list,
370370
_,
371+
_,
371372
) = parse_pytorch_model(config, verbose=False)
372373

373374
def make_layer_config(layer):

0 commit comments

Comments
 (0)