@@ -626,10 +626,10 @@ def travel(node):
626
626
627
627
628
628
def _parse_nodes_v2 (graph , inference_nodeset , graph_inputs , keras_node_dict , node , varset , visited , q_overall ):
629
- layer_key = None
629
+ layer_key , model_ = ( None , None )
630
630
current_layer_outputs = {}
631
631
if node .name in keras_node_dict :
632
- layer_key = keras_node_dict [node .name ][ 0 ]
632
+ layer_key , model_ = keras_node_dict [node .name ]
633
633
else :
634
634
ts_out = node .outputs [0 ]
635
635
kh_ = getattr (ts_out , '_keras_history' , None )
@@ -648,7 +648,7 @@ def _parse_nodes_v2(graph, inference_nodeset, graph_inputs, keras_node_dict, nod
648
648
_create_identity (ts_ .op .inputs [0 ], ts_ , varset )
649
649
visited .add (ts_ .op )
650
650
_advance_by_input (ts_ .op , [ts_ .op ], list (), set (), graph_inputs , q_overall )
651
- return None
651
+ return None , model_
652
652
else :
653
653
layer_info = LayerInfo .create (node , layer_key ,
654
654
{** keras_node_dict , ** current_layer_outputs }, inference_nodeset )
@@ -662,7 +662,7 @@ def _parse_nodes_v2(graph, inference_nodeset, graph_inputs, keras_node_dict, nod
662
662
layer_info .inputs .extend (input_ .outputs )
663
663
664
664
layer_info .nodelist = [n_ for n_ in layer_info .nodelist if not is_placeholder_node (n_ )]
665
- return layer_info
665
+ return layer_info , model_
666
666
667
667
668
668
def _parse_graph_core_v2 (graph , keras_node_dict , topology , top_scope , output_names ):
@@ -698,14 +698,16 @@ def _parse_graph_core_v2(graph, keras_node_dict, topology, top_scope, output_nam
698
698
if node in input_nodes or node in visited :
699
699
continue
700
700
701
- layer_info = _parse_nodes_v2 (graph , inference_nodeset , input_nodes , keras_node_dict , node ,
701
+ layer_info , model_ = _parse_nodes_v2 (graph , inference_nodeset , input_nodes , keras_node_dict , node ,
702
702
varset , visited , q_overall )
703
703
if not layer_info : # already processed by the parse_nodes_XX
704
704
continue
705
705
706
706
k2o_logger ().debug ('Processing a keras layer - (%s: %s)' % (layer_info .layer .name , type (layer_info .layer )) if
707
707
layer_info .layer else (layer_info .nodelist [0 ].name , "Custom_Layer" ))
708
- if layer_info .layer and get_converter (type (layer_info .layer )):
708
+ if layer_info .layer and isinstance (layer_info .layer , keras .layers .TimeDistributed ):
709
+ _on_parsing_time_distributed_layer (graph , layer_info .nodelist , layer_info .layer , model_ , varset )
710
+ elif layer_info .layer and get_converter (type (layer_info .layer )):
709
711
on_parsing_keras_layer_v2 (graph , layer_info , varset )
710
712
else :
711
713
_on_parsing_tf_nodes (graph , layer_info .nodelist , varset , topology .debug_mode )
0 commit comments