Skip to content
This repository was archived by the owner on Oct 13, 2021. It is now read-only.

Commit 3a6e4ad

Browse files
cjermainjiafatom
andauthored
Custom Masking value (#389)
* Adding broken test for using custom mask_value * Allowing custom mask_value to be used * Fixing dtypes in mask_value Equal operator * Adding meaningful exception when the masking value is not specified Co-authored-by: David Fan <[email protected]>
1 parent 194b4f9 commit 3a6e4ad

File tree

4 files changed

+23
-2
lines changed

4 files changed

+23
-2
lines changed

keras2onnx/_parser_1x.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@ def on_parsing_keras_layer(graph, node_list, layer, kenode, model, varset, prefi
111111
mts_var = varset.get_local_variable_or_declare_one(mts_name, infer_variable_type(om_, varset.target_opset))
112112
operator.add_output_mask(mts_var)
113113

114+
if hasattr(layer, 'mask_value') and layer.mask_value is not None:
115+
operator.mask_value = layer.mask_value
116+
114117
cvt = get_converter(operator.type)
115118
if cvt is not None and hasattr(cvt, 'shape_infer'):
116119
operator.shape_infer = cvt.shape_infer

keras2onnx/common/intop.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(self, onnx_name, scope, type, raw_operator, target_opset):
2626
self.input_masks = []
2727
self.outputs = []
2828
self.output_masks = []
29+
self.mask_value = None
2930
self.nodelist = None
3031
self.is_evaluated = None
3132
self.target_opset = target_opset

keras2onnx/ke2onnx/main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,10 @@ def convert_keras_flatten(scope, operator, container):
6868

6969

7070
def _apply_not_equal(oopb, target_opset, operator):
71+
if operator.mask_value is None:
72+
raise ValueError("Masking value was not properly parsed for layer '{}'".format(operator.full_name))
7173
if target_opset >= 11:
72-
equal_out = oopb.add_node('Equal', [operator.inputs[0].full_name, np.array([0], dtype='float32')],
74+
equal_out = oopb.add_node('Equal', [operator.inputs[0].full_name, np.array([operator.mask_value], dtype='float32')],
7375
operator.full_name + 'mask')
7476
not_o = oopb.add_node('Not', equal_out,
7577
name=operator.full_name + '_not')
@@ -78,7 +80,7 @@ def _apply_not_equal(oopb, target_opset, operator):
7880
"the masking layer result may be incorrect if the model input is in range (0, 1.0).")
7981
equal_input_0 = oopb.add_node('Cast', [operator.inputs[0].full_name],
8082
operator.full_name + '_input_cast', to=6)
81-
equal_out = oopb.add_node('Equal', [equal_input_0, np.array([0], dtype='int32')],
83+
equal_out = oopb.add_node('Equal', [equal_input_0, np.array([operator.mask_value], dtype='int32')],
8284
operator.full_name + 'mask')
8385
not_o = oopb.add_node('Not', equal_out,
8486
name=operator.full_name + '_not')

tests/test_layers.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1788,6 +1788,21 @@ def test_masking(self):
17881788
expected = model.predict(x)
17891789
self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files))
17901790

1791+
@unittest.skipIf(is_tf2 and is_tf_keras, 'TODO')
1792+
def test_masking_value(self):
1793+
timesteps, features = (3, 5)
1794+
mask_value = 5.
1795+
model = Sequential([
1796+
keras.layers.Masking(mask_value=mask_value, input_shape=(timesteps, features)),
1797+
LSTM(8, return_state=False, return_sequences=False)
1798+
])
1799+
1800+
onnx_model = keras2onnx.convert_keras(model, model.name)
1801+
x = np.random.uniform(100, 999, size=(2, 3, 5)).astype(np.float32)
1802+
x[1, :, :] = mask_value
1803+
expected = model.predict(x)
1804+
self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files))
1805+
17911806
@unittest.skipIf(is_tf2 and is_tf_keras, 'TODO')
17921807
def test_masking_custom(self):
17931808
class MyPoolingMask(keras.layers.Layer):

0 commit comments

Comments
 (0)