Skip to content
This repository was archived by the owner on Oct 13, 2021. It is now read-only.

Commit 89b0c45

Browse files
Adjust input output sizes when any dim is None (#480)
Whenever inputs and outputs have None as dim (except for batch size), after conversion they will have 0 as the dimension. Batch size will be converted to 'N'. input(None, None) -> input(N, 0) With this fix even other dims will have variables like M1, M2. input(None, None) -> input(N, M1)
1 parent 022fb2a commit 89b0c45

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

keras2onnx/_parser_tf.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@ def adjust_input_batch_size(var_type):
5454
return var_type
5555

5656

57+
def adjust_input_output_size(var_type, dim_variable_counter):
58+
if len(var_type.shape) > 0:
59+
for dim in range(1, len(var_type.shape)):
60+
if var_type.shape[dim] is None:
61+
dim_variable_counter += 1
62+
var_type.shape[dim] = 'M' + str(dim_variable_counter)
63+
return dim_variable_counter
64+
65+
5766
def _get_layer_name(reserved, ts_or_op):
5867
return ts_or_op.rsplit('/', 1)[0]
5968

keras2onnx/parser.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from ._consts import TYPES
1616
from ._tf_ops import pass_thru_converter
1717
from ._parser_tf import (infer_variable_type, LayerInfo, is_placeholder_node,
18-
tsname_to_node, on_parsing_keras_layer_v2, adjust_input_batch_size as _adjust_input_batch_size)
18+
tsname_to_node, on_parsing_keras_layer_v2, adjust_input_batch_size as _adjust_input_batch_size,
19+
adjust_input_output_size as _adjust_input_output_size)
1920
from ._parser_1x import (extract_inbound_nodes,
2021
list_input_tensors, list_input_mask, list_output_mask,
2122
list_output_tensors, list_input_shapes, list_output_shapes, on_parsing_keras_layer)
@@ -756,6 +757,7 @@ def parse_graph(topo, graph, target_opset, output_names, keras_node_dict):
756757
"""
757758
top_level = topo.declare_scope('__root')
758759

760+
dim_variable_counter = 0
759761
# Create the onnx model input name before parsing to keep ...
760762
# ... the model input names are identical to the original Keras model.
761763
for idx_ in range(len(topo.raw_model.model.inputs)):
@@ -765,6 +767,7 @@ def parse_graph(topo, graph, target_opset, output_names, keras_node_dict):
765767
idx_key = list(topo.raw_model.model.inputs.keys())[idx_]
766768
input_ts = topo.raw_model.model.inputs[idx_key]
767769
var_type = _adjust_input_batch_size(infer_variable_type(input_ts, target_opset))
770+
dim_variable_counter = _adjust_input_output_size(var_type, dim_variable_counter)
768771
str_value = input_ts.name
769772
var0 = None
770773
if hasattr(topo.raw_model.model, 'input_names'):
@@ -789,6 +792,7 @@ def parse_graph(topo, graph, target_opset, output_names, keras_node_dict):
789792
for idx_, ts_ in enumerate(output_tensors):
790793
op = top_level.declare_local_operator(TYPES.Identity)
791794
var_type = _adjust_input_batch_size(infer_variable_type(ts_, target_opset))
795+
dim_variable_counter = _adjust_input_output_size(var_type, dim_variable_counter)
792796
str_value = ts_.name
793797
use_ts_name = False
794798
if hasattr(topo.raw_model.model, 'output_names'):

0 commit comments

Comments
 (0)