@@ -57,6 +57,20 @@ def _get_layer_name(reserved, ts_or_op):
57
57
return ts_or_op .rsplit ('/' , 1 )[0 ]
58
58
59
59
60
+ def _get_input_mask (layer ):
61
+ # type: (keras.models.Layer) -> []
62
+ if hasattr (layer , 'input_mask' ) and layer .input_mask is not None :
63
+ return layer .input_mask if isinstance (layer .input_mask , (list , tuple )) else [layer .input_mask ]
64
+ return []
65
+
66
+
67
+ def _get_output_mask (layer ):
68
+ # type: (keras.models.Layer) -> []
69
+ if hasattr (layer , 'output_mask' ) and layer .output_mask is not None :
70
+ return layer .output_mask if isinstance (layer .output_mask , (list , tuple )) else [layer .output_mask ]
71
+ return []
72
+
73
+
60
74
class LayerInfo (object ):
61
75
def __init__ (self , _ly ):
62
76
self .layer = _ly
@@ -102,6 +116,7 @@ def create(node, layer, outputs_map, inference_nodeset):
102
116
next_itr .clear ()
103
117
for n_ in visited :
104
118
for i_ in n_ .inputs :
119
+ # in layer_spec model, the layer name will be checked
105
120
if fstr_list is not None and i_ .op .name .find (layer_name ) == - 1 :
106
121
continue
107
122
if i_ .op in visited or i_ .op not in inference_nodeset :
@@ -255,6 +270,10 @@ def extract_outputs_from_inbound_nodes(model):
255
270
if op_name not in output_dict :
256
271
output_dict [op_name ] = (model , None )
257
272
273
+ for ts_ in _get_output_mask (model ):
274
+ if ts_ is not None :
275
+ output_dict [ts_ .op .name ] = (model , model )
276
+
258
277
return output_dict
259
278
260
279
@@ -269,64 +288,43 @@ def build_layer_output_from_model(model, output_dict, input_names, output_names)
269
288
return graph
270
289
271
290
272
- # layer.input and layer_info.inputs are different for masking layer,
273
- # we rely on layer.inputs for this case.
274
- def _get_layer_endpoints (layer_endpoints , layer_info_end_points ):
275
- end_points = []
276
- end_point_candidates = layer_endpoints if isinstance (layer_endpoints , list ) else [layer_endpoints ]
277
- layer_info_end_points_name = [point .name for point in layer_info_end_points ]
278
- for end_point_ in end_point_candidates :
279
- if end_point_ .name in layer_info_end_points_name :
280
- end_points .append (end_point_ )
281
- return end_points
282
-
283
-
284
291
def on_parsing_keras_layer_v2 (graph , layer_info , varset , prefix = None ):
285
292
layer = layer_info .layer
286
293
node_list = layer_info .nodelist
287
294
operator = varset .declare_local_operator (type (layer ), raw_model = layer , op_name = layer .name )
288
295
operator .nodelist = node_list
289
296
290
- inputs = layer_info .inputs
291
- outputs = layer_info .outputs
292
- if hasattr (layer , 'input' ):
293
- end_point_flag = hasattr (layer , 'input_mask' ) and layer .input_mask is not None
294
- end_point_flag = end_point_flag or isinstance (layer_info .layer , keras .layers .Bidirectional )
295
- if end_point_flag :
296
- inputs = _get_layer_endpoints (layer .input , layer_info .inputs )
297
- outputs = _get_layer_endpoints (layer .output , layer_info .outputs )
298
-
299
297
if prefix is None : # prefix is designed for the distinguish among the shared model instances.
300
298
prefix = ''
301
299
302
- for n_ , o_ in enumerate (outputs ):
303
- oname = prefix + o_ .name
304
- k2o_logger ().debug ('output: ' + oname )
305
- o1 = varset .get_local_variable_or_declare_one (oname , infer_variable_type (o_ , varset .target_opset ))
306
- operator .add_output (o1 )
307
-
308
- for i_ in inputs :
309
- iname = prefix + i_ .name
310
- k2o_logger ().debug ('input : ' + iname )
311
- var_type = adjust_input_batch_size (infer_variable_type (i_ , varset .target_opset ))
312
- i0 = varset .get_local_variable_or_declare_one (iname , var_type )
313
- operator .add_input (i0 )
314
-
315
- if hasattr (layer , 'input_mask' ) and layer .input_mask is not None :
316
- in_mask = layer .input_mask if isinstance (layer .input_mask , (list , tuple )) else [layer .input_mask ]
317
- for im_ in [m_ for m_ in in_mask if m_ is not None ]:
318
- mts_name = im_ .name # input mask in a shared model is not supported yet, why is it needed?
319
- k2o_logger ().debug ('input mask: ' + mts_name )
320
- mts_var = varset .get_local_variable_or_declare_one (mts_name , infer_variable_type (im_ , varset .target_opset ))
321
- operator .add_input_mask (mts_var )
300
+ input_masks = _get_input_mask (layer )
301
+ output_masks = _get_output_mask (layer )
302
+ for o_ in layer_info .outputs :
303
+ if o_ not in output_masks : # the layer converter will handle output_mask by itself.
304
+ oname = prefix + o_ .name
305
+ k2o_logger ().debug ('output: ' + oname )
306
+ o1 = varset .get_local_variable_or_declare_one (oname , infer_variable_type (o_ , varset .target_opset ))
307
+ operator .add_output (o1 )
322
308
323
- if hasattr (layer , 'output_mask' ) and layer .output_mask is not None :
324
- out_mask = layer .output_mask if isinstance (layer .output_mask , (list , tuple )) else [layer .output_mask ]
325
- for om_ in [m_ for m_ in out_mask if m_ is not None ]:
326
- mts_name = prefix + om_ .name
327
- k2o_logger ().debug ('output mask: ' + mts_name )
328
- mts_var = varset .get_local_variable_or_declare_one (mts_name , infer_variable_type (om_ , varset .target_opset ))
329
- operator .add_output_mask (mts_var )
309
+ for i_ in layer_info .inputs :
310
+ if i_ not in input_masks : # the layer converter will handle input_mask by itself.
311
+ iname = prefix + i_ .name
312
+ k2o_logger ().debug ('input : ' + iname )
313
+ var_type = adjust_input_batch_size (infer_variable_type (i_ , varset .target_opset ))
314
+ i0 = varset .get_local_variable_or_declare_one (iname , var_type )
315
+ operator .add_input (i0 )
316
+
317
+ for om_ in [m_ for m_ in output_masks if m_ is not None ]:
318
+ mts_name = prefix + om_ .name
319
+ k2o_logger ().debug ('output mask: ' + mts_name )
320
+ mts_var = varset .get_local_variable_or_declare_one (mts_name , infer_variable_type (om_ , varset .target_opset ))
321
+ operator .add_output_mask (mts_var )
322
+
323
+ for im_ in [m_ for m_ in input_masks if m_ is not None ]:
324
+ mts_name = im_ .name # input mask in a shared model is not supported yet, why is it needed?
325
+ k2o_logger ().debug ('input mask: ' + mts_name )
326
+ mts_var = varset .get_local_variable_or_declare_one (mts_name , infer_variable_type (im_ , varset .target_opset ))
327
+ operator .add_input_mask (mts_var )
330
328
331
329
if hasattr (layer , 'mask_value' ) and layer .mask_value is not None :
332
330
operator .mask_value = layer .mask_value
0 commit comments