This repository was archived by the owner on Oct 13, 2021. It is now read-only.

Commit b545f48

Fix some tf2.x conversion bugs. (#443)

1 parent: c5a69af

File tree

5 files changed: +92 -70 lines changed


keras2onnx/_parse_tf.py

Lines changed: 46 additions & 48 deletions
@@ -57,6 +57,20 @@ def _get_layer_name(reserved, ts_or_op):
     return ts_or_op.rsplit('/', 1)[0]
 
 
+def _get_input_mask(layer):
+    # type: (keras.models.Layer) -> []
+    if hasattr(layer, 'input_mask') and layer.input_mask is not None:
+        return layer.input_mask if isinstance(layer.input_mask, (list, tuple)) else [layer.input_mask]
+    return []
+
+
+def _get_output_mask(layer):
+    # type: (keras.models.Layer) -> []
+    if hasattr(layer, 'output_mask') and layer.output_mask is not None:
+        return layer.output_mask if isinstance(layer.output_mask, (list, tuple)) else [layer.output_mask]
+    return []
+
+
 class LayerInfo(object):
     def __init__(self, _ly):
         self.layer = _ly
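The two new helpers normalize a layer's Keras masks into a plain list, so the parsing code can treat the no-mask, single-mask, and multi-mask cases uniformly. A minimal sketch of the intended behavior (illustrative, not part of the commit; it assumes a tf.keras version where layers still expose the input_mask/output_mask properties):

    import tensorflow as tf

    inp = tf.keras.Input(shape=(4, 8))
    masking = tf.keras.layers.Masking(mask_value=0.0)
    masked = masking(inp)  # calling the layer populates its Keras mask

    # masking.output_mask is now a single tensor, so the helper would wrap it:
    #   _get_output_mask(masking) -> [<mask tensor>]
    # A layer with no mask set (or one that was never called) falls through
    # to the hasattr/None checks and yields:
    #   _get_output_mask(some_other_layer) -> []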
@@ -102,6 +116,7 @@ def create(node, layer, outputs_map, inference_nodeset):
             next_itr.clear()
             for n_ in visited:
                 for i_ in n_.inputs:
+                    # in a layer_spec model, the layer name will be checked
                     if fstr_list is not None and i_.op.name.find(layer_name) == -1:
                         continue
                     if i_.op in visited or i_.op not in inference_nodeset:
@@ -255,6 +270,10 @@ def extract_outputs_from_inbound_nodes(model):
         if op_name not in output_dict:
             output_dict[op_name] = (model, None)
 
+    for ts_ in _get_output_mask(model):
+        if ts_ is not None:
+            output_dict[ts_.op.name] = (model, model)
+
     return output_dict
 
 

@@ -269,64 +288,43 @@ def build_layer_output_from_model(model, output_dict, input_names, output_names):
     return graph
 
 
-# layer.input and layer_info.inputs are different for masking layer,
-# we rely on layer.inputs for this case.
-def _get_layer_endpoints(layer_endpoints, layer_info_end_points):
-    end_points = []
-    end_point_candidates = layer_endpoints if isinstance(layer_endpoints, list) else [layer_endpoints]
-    layer_info_end_points_name = [point.name for point in layer_info_end_points]
-    for end_point_ in end_point_candidates:
-        if end_point_.name in layer_info_end_points_name:
-            end_points.append(end_point_)
-    return end_points
-
-
 def on_parsing_keras_layer_v2(graph, layer_info, varset, prefix=None):
     layer = layer_info.layer
     node_list = layer_info.nodelist
     operator = varset.declare_local_operator(type(layer), raw_model=layer, op_name=layer.name)
     operator.nodelist = node_list
 
-    inputs = layer_info.inputs
-    outputs = layer_info.outputs
-    if hasattr(layer, 'input'):
-        end_point_flag = hasattr(layer, 'input_mask') and layer.input_mask is not None
-        end_point_flag = end_point_flag or isinstance(layer_info.layer, keras.layers.Bidirectional)
-        if end_point_flag:
-            inputs = _get_layer_endpoints(layer.input, layer_info.inputs)
-            outputs = _get_layer_endpoints(layer.output, layer_info.outputs)
-
     if prefix is None:  # prefix is designed to distinguish among shared model instances.
         prefix = ''
 
-    for n_, o_ in enumerate(outputs):
-        oname = prefix + o_.name
-        k2o_logger().debug('output: ' + oname)
-        o1 = varset.get_local_variable_or_declare_one(oname, infer_variable_type(o_, varset.target_opset))
-        operator.add_output(o1)
-
-    for i_ in inputs:
-        iname = prefix + i_.name
-        k2o_logger().debug('input : ' + iname)
-        var_type = adjust_input_batch_size(infer_variable_type(i_, varset.target_opset))
-        i0 = varset.get_local_variable_or_declare_one(iname, var_type)
-        operator.add_input(i0)
-
-    if hasattr(layer, 'input_mask') and layer.input_mask is not None:
-        in_mask = layer.input_mask if isinstance(layer.input_mask, (list, tuple)) else [layer.input_mask]
-        for im_ in [m_ for m_ in in_mask if m_ is not None]:
-            mts_name = im_.name  # input mask in a shared model is not supported yet, why is it needed?
-            k2o_logger().debug('input mask: ' + mts_name)
-            mts_var = varset.get_local_variable_or_declare_one(mts_name, infer_variable_type(im_, varset.target_opset))
-            operator.add_input_mask(mts_var)
+    input_masks = _get_input_mask(layer)
+    output_masks = _get_output_mask(layer)
+    for o_ in layer_info.outputs:
+        if o_ not in output_masks:  # the layer converter will handle output_mask by itself.
+            oname = prefix + o_.name
+            k2o_logger().debug('output: ' + oname)
+            o1 = varset.get_local_variable_or_declare_one(oname, infer_variable_type(o_, varset.target_opset))
+            operator.add_output(o1)
 
-    if hasattr(layer, 'output_mask') and layer.output_mask is not None:
-        out_mask = layer.output_mask if isinstance(layer.output_mask, (list, tuple)) else [layer.output_mask]
-        for om_ in [m_ for m_ in out_mask if m_ is not None]:
-            mts_name = prefix + om_.name
-            k2o_logger().debug('output mask: ' + mts_name)
-            mts_var = varset.get_local_variable_or_declare_one(mts_name, infer_variable_type(om_, varset.target_opset))
-            operator.add_output_mask(mts_var)
+    for i_ in layer_info.inputs:
+        if i_ not in input_masks:  # the layer converter will handle input_mask by itself.
+            iname = prefix + i_.name
+            k2o_logger().debug('input : ' + iname)
+            var_type = adjust_input_batch_size(infer_variable_type(i_, varset.target_opset))
+            i0 = varset.get_local_variable_or_declare_one(iname, var_type)
+            operator.add_input(i0)
+
+    for om_ in [m_ for m_ in output_masks if m_ is not None]:
+        mts_name = prefix + om_.name
+        k2o_logger().debug('output mask: ' + mts_name)
+        mts_var = varset.get_local_variable_or_declare_one(mts_name, infer_variable_type(om_, varset.target_opset))
+        operator.add_output_mask(mts_var)
+
+    for im_ in [m_ for m_ in input_masks if m_ is not None]:
+        mts_name = im_.name  # input mask in a shared model is not supported yet.
+        k2o_logger().debug('input mask: ' + mts_name)
+        mts_var = varset.get_local_variable_or_declare_one(mts_name, infer_variable_type(im_, varset.target_opset))
+        operator.add_input_mask(mts_var)
 
     if hasattr(layer, 'mask_value') and layer.mask_value is not None:
         operator.mask_value = layer.mask_value
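The key behavioral change: instead of remapping endpoints through the removed _get_layer_endpoints, the function now declares every layer_info input/output except those that are themselves mask tensors, and routes masks through the dedicated add_input_mask/add_output_mask channels. A toy restatement of that filtering pattern (names are illustrative, not from the repository):

    # Mask tensors are excluded from the regular outputs and handled separately,
    # mirroring the `if o_ not in output_masks` test above.
    outputs = ['lstm/strided_slice:0', 'masking/NotEqual:0']
    output_masks = ['masking/NotEqual:0']

    regular = [o for o in outputs if o not in output_masks]
    assert regular == ['lstm/strided_slice:0']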

keras2onnx/common/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -9,12 +9,15 @@
 from .intop import Operator
 from .interim import OnnxObjectContainer, InterimContext, Variable
 
+
 # keras2onnx common code has been refactored into onnxconverter-common.
 
 def name_func(scope, operator):
     """Returns a function that can generate unique names for an operator based on the
     scope.
     """
+
     def _name_func(name):
         return scope.get_unique_operator_name(operator.full_name + '_' + name)
+
     return _name_func
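All three additions here are blank lines (PEP 8 spacing), but name_func itself is worth a glance: it returns a closure that prefixes generated names with the operator's full name. A hedged usage sketch with stand-in scope/operator objects (the real ones come from onnxconverter-common; these stubs are illustrative only):

    class _StubScope:
        def __init__(self):
            self._taken = set()

        def get_unique_operator_name(self, name):
            # Append a counter until the name is unique within this scope.
            candidate, i = name, 0
            while candidate in self._taken:
                i += 1
                candidate = '%s_%d' % (name, i)
            self._taken.add(candidate)
            return candidate

    class _StubOperator:
        full_name = 'lstm_1'

    _name = name_func(_StubScope(), _StubOperator())
    print(_name('W'))  # -> 'lstm_1_W'
    print(_name('W'))  # -> 'lstm_1_W_1' (deduplicated by the scope)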

keras2onnx/ke2onnx/lstm.py

Lines changed: 5 additions & 3 deletions
@@ -49,6 +49,7 @@ def extract_params(op, hidden_size, input_size):
 
     return W_x, W_h, b
 
+
 def build_parameters(scope, operator, container, bidirectional=False):
     """Returns the parameter initialization values after extracting them from the LSTM layer.
     """
@@ -106,9 +107,9 @@ def build_parameters(scope, operator, container, bidirectional=False):
     tensor_b = _name('B')
     container.add_initializer(tensor_b, TensorProto.FLOAT, B_shape, B)
 
-
     return tensor_w, tensor_r, tensor_b
 
+
 def build_initial_states(scope, operator, container, bidirectional=False):
     """Builds the initial hidden and cell states for the LSTM layer.
     """
@@ -118,8 +119,8 @@
 
     # Determine if the cell states are set
     has_c = (
-        (len(operator.inputs) > 1 and not bidirectional) or
-        (len(operator.inputs) > 3 and bidirectional)
+            (len(operator.inputs) > 1 and not bidirectional) or
+            (len(operator.inputs) > 3 and bidirectional)
     )
     if not has_c:
         return initial_h, ''
@@ -183,6 +184,7 @@ def build_attributes(scope, operator, container, bidirectional=False):
     ]))
     return attrs
 
+
 def build_output(scope, operator, container, output_names, bidirectional=False):
     """Builds the output operators for the LSTM layer.
     """

keras2onnx/parser.py

Lines changed: 9 additions & 5 deletions
@@ -20,13 +20,15 @@
     list_input_tensors, list_input_mask, list_output_mask,
     list_output_tensors, list_input_shapes, list_output_shapes, on_parsing_keras_layer)
 
+
 def _find_node(nodes, name):
     try:
         opname = tsname_to_node(name)
         return next(n_ for n_ in nodes if n_.name == opname)
     except StopIteration:
         return None
 
+
 def _locate_inputs_by_node(node_list, varset):
     inputs = {}
     for n_ in node_list:
@@ -480,7 +482,7 @@ def _advance_by_input(cur_node, layer_nodes, subgraph, inputs, graph_inputs, q_overall):
     for input_ in cur_node.inputs:
         predecessor = input_.op
         if is_placeholder_node(predecessor):
-            # mysteriously, some bn layer create a placeholder node 'scale' in tf2.x.
+            # the tf.keras BN layer sometimes creates a placeholder node 'scale' in tf2.x.
             # Since the BN layer is converted as a whole, it's fine to just filter this node out.
             if not re.match(r"batch_normalization_\d+\/scale$", predecessor.name):
                 inputs.add(predecessor)
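The filter relies on the placeholder's op name matching a fixed pattern. A quick illustration of what the regex accepts (node names are illustrative; note that re.match anchors at the start of the string):

    import re

    pattern = r"batch_normalization_\d+\/scale$"
    assert re.match(pattern, "batch_normalization_3/scale")       # filtered out
    assert not re.match(pattern, "batch_normalization_3/gamma")   # kept as input
    assert not re.match(pattern, "dense_1/scale")                 # kept as input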
@@ -655,7 +657,6 @@ def _parse_nodes_v2(graph, inference_nodeset, graph_inputs, keras_node_dict, node,
     nodelist = []
     layer_inputs = _visit_nodelist(layer_info.nodelist, graph_inputs, None, keras_node_dict, node, nodelist,
                                    q_overall, visited)
-
     sorted_inputs = _sorted_inputs(layer_info.nodelist, layer_info.outputs, layer_inputs)
     for input_ in sorted_inputs:
         layer_info.inputs.extend(input_.outputs)
@@ -691,15 +692,18 @@ def _parse_graph_core_v2(graph, keras_node_dict, topology, top_scope, output_names):
         q_overall.put_nowait(n_)
 
     visited = set()  # since the output could be shared among the successor nodes.
-    inference_nodeset = _build_inference_nodeset(graph, model_outputs)
+    # Some complicated layers have nodes that cannot be visited from the graph outputs,
+    # so the layer outputs are added to the visiting set to avoid missing nodes.
+    layer_outputs = [graph.get_operation_by_name(nm_) for nm_ in keras_node_dict]
+    inference_nodeset = _build_inference_nodeset(graph, model_outputs + layer_outputs)
     while not q_overall.empty():
         node = q_overall.get_nowait()
         if node in input_nodes or node in visited or node not in inference_nodeset:
             continue
 
         layer_info, model_ = _parse_nodes_v2(graph, inference_nodeset, input_nodes, keras_node_dict, node,
-                                            varset, visited, q_overall)
-        if not layer_info:  # already processed by the parse_nodes_XX
+                                             varset, visited, q_overall)
+        if not layer_info:  # already processed by _parse_nodes_v2
             continue
 
         k2o_logger().debug('Processing a keras layer - (%s: %s)' % (layer_info.layer.name, type(layer_info.layer)) if
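The substantive fix: the inference nodeset used to be seeded only from model_outputs, so ops reachable only from an intermediate layer's outputs (e.g. mask tensors) were dropped. A hedged sketch of the general idea, assuming _build_inference_nodeset performs a backward reachability walk over op inputs (the helper below is illustrative, not the library's implementation):

    def _reachable_ops(seed_ops):
        # Backward BFS over op inputs: collect every op that can reach a seed.
        visited, stack = set(), list(seed_ops)
        while stack:
            op = stack.pop()
            if op in visited:
                continue
            visited.add(op)
            stack.extend(ts.op for ts in op.inputs)
        return visited

    # Seeding with model_outputs + layer_outputs keeps ops that feed a layer
    # output but not a model output inside the nodeset:
    # nodeset = _reachable_ops(model_outputs + layer_outputs)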
