Skip to content
This repository was archived by the owner on Oct 13, 2021. It is now read-only.

Commit cad12fc

Browse files
authored
Handle TimeDistributed layer for tf2 and tf.keras (#420)
1 parent aad8ed8 commit cad12fc

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

keras2onnx/parser.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -626,10 +626,10 @@ def travel(node):
626626

627627

628628
def _parse_nodes_v2(graph, inference_nodeset, graph_inputs, keras_node_dict, node, varset, visited, q_overall):
629-
layer_key = None
629+
layer_key, model_ = (None, None)
630630
current_layer_outputs = {}
631631
if node.name in keras_node_dict:
632-
layer_key = keras_node_dict[node.name][0]
632+
layer_key, model_ = keras_node_dict[node.name]
633633
else:
634634
ts_out = node.outputs[0]
635635
kh_ = getattr(ts_out, '_keras_history', None)
@@ -648,7 +648,7 @@ def _parse_nodes_v2(graph, inference_nodeset, graph_inputs, keras_node_dict, nod
648648
_create_identity(ts_.op.inputs[0], ts_, varset)
649649
visited.add(ts_.op)
650650
_advance_by_input(ts_.op, [ts_.op], list(), set(), graph_inputs, q_overall)
651-
return None
651+
return None, model_
652652
else:
653653
layer_info = LayerInfo.create(node, layer_key,
654654
{**keras_node_dict, **current_layer_outputs}, inference_nodeset)
@@ -662,7 +662,7 @@ def _parse_nodes_v2(graph, inference_nodeset, graph_inputs, keras_node_dict, nod
662662
layer_info.inputs.extend(input_.outputs)
663663

664664
layer_info.nodelist = [n_ for n_ in layer_info.nodelist if not is_placeholder_node(n_)]
665-
return layer_info
665+
return layer_info, model_
666666

667667

668668
def _parse_graph_core_v2(graph, keras_node_dict, topology, top_scope, output_names):
@@ -698,14 +698,16 @@ def _parse_graph_core_v2(graph, keras_node_dict, topology, top_scope, output_nam
698698
if node in input_nodes or node in visited:
699699
continue
700700

701-
layer_info = _parse_nodes_v2(graph, inference_nodeset, input_nodes, keras_node_dict, node,
701+
layer_info, model_ = _parse_nodes_v2(graph, inference_nodeset, input_nodes, keras_node_dict, node,
702702
varset, visited, q_overall)
703703
if not layer_info: # already processed by the parse_nodes_XX
704704
continue
705705

706706
k2o_logger().debug('Processing a keras layer - (%s: %s)' % (layer_info.layer.name, type(layer_info.layer)) if
707707
layer_info.layer else (layer_info.nodelist[0].name, "Custom_Layer"))
708-
if layer_info.layer and get_converter(type(layer_info.layer)):
708+
if layer_info.layer and isinstance(layer_info.layer, keras.layers.TimeDistributed):
709+
_on_parsing_time_distributed_layer(graph, layer_info.nodelist, layer_info.layer, model_, varset)
710+
elif layer_info.layer and get_converter(type(layer_info.layer)):
709711
on_parsing_keras_layer_v2(graph, layer_info, varset)
710712
else:
711713
_on_parsing_tf_nodes(graph, layer_info.nodelist, varset, topology.debug_mode)

tests/test_layers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1942,7 +1942,6 @@ def compute_output_shape(self, input_shape):
19421942
expected = model.predict(x)
19431943
self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files))
19441944

1945-
@unittest.skipIf(is_tf2 and is_tf_keras, 'TODO')
19461945
def test_timedistributed(self):
19471946
keras_model = keras.Sequential()
19481947
keras_model.add(TimeDistributed(Dense(8), input_shape=(10, 16)))

0 commit comments

Comments
 (0)