This repository was archived by the owner on Oct 13, 2021. It is now read-only.

Masking RNN with zeros input #386

Merged · 11 commits · Feb 27, 2020
14 changes: 13 additions & 1 deletion keras2onnx/ke2onnx/gru.py
@@ -50,7 +50,13 @@ def convert_keras_gru(scope, operator, container):
gru_input_names.append('')

# sequence lens
gru_input_names.append('')
uses_masking_layer = len(operator.input_masks) == 1
if uses_masking_layer:
# Mask using sequence_lens input
sequence_lengths = scope.get_unique_variable_name(operator.full_name + '_seq_lens')
gru_input_names.append(sequence_lengths)
else:
gru_input_names.append('')
# initial_h
if len(operator.inputs) == 1:
gru_input_names.append('')
@@ -88,6 +94,12 @@ def convert_keras_gru(scope, operator, container):
gru_h_name = scope.get_unique_variable_name('gru_h')
gru_output_names = [gru_y_name, gru_h_name]
oopb = OnnxOperatorBuilder(container, scope)

if uses_masking_layer:
mask_cast = oopb.apply_cast(operator.input_masks[0].full_name, to=oopb.int32, name=operator.full_name + '_mask_cast')
oopb.add_node_with_output('ReduceSum', mask_cast, sequence_lengths, keepdims=False, axes=[-1], name=operator.full_name + '_mask_sum')


oopb.apply_op_with_output('apply_gru',
gru_input_names,
gru_output_names,
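
What the new Cast + ReduceSum pair computes, shown as a minimal numpy sketch (illustration only, not code from this PR): the boolean mask produced by the Keras Masking layer has shape (batch, timesteps); casting it to int32 and summing over the last axis yields the per-sample lengths fed to the ONNX GRU's optional sequence_lens input.

import numpy as np

mask = np.array([[True, True, True],
                 [True, False, False]])             # Keras Masking output, shape (batch, timesteps)
sequence_lens = mask.astype(np.int32).sum(axis=-1)  # Cast to int32, then ReduceSum over the last axis
print(sequence_lens)                                # -> [3 1]
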
13 changes: 12 additions & 1 deletion keras2onnx/ke2onnx/lstm.py
@@ -92,7 +92,13 @@ def convert_keras_lstm(scope, operator, container):
lstm_input_names.append('')

# sequence_lens
lstm_input_names.append('')
uses_masking_layer = len(operator.input_masks) == 1
if uses_masking_layer:
# Mask using sequence_lens input
sequence_lengths = scope.get_unique_variable_name(operator.full_name + '_seq_lens')
lstm_input_names.append(sequence_lengths)
else:
lstm_input_names.append('')
# initial_h
if len(operator.inputs) <= 1:
lstm_input_names.append('')
@@ -149,6 +155,11 @@ def convert_keras_lstm(scope, operator, container):
lstm_output_names.append(lstm_c_name)

oopb = OnnxOperatorBuilder(container, scope)

if uses_masking_layer:
mask_cast = oopb.apply_cast(operator.input_masks[0].full_name, to=oopb.int32, name=operator.full_name + '_mask_cast')
oopb.add_node_with_output('ReduceSum', mask_cast, sequence_lengths, keepdims=False, axes=[-1], name=operator.full_name + '_mask_sum')

oopb.apply_op_with_output('apply_lstm',
lstm_input_names,
lstm_output_names,
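
A hedged end-to-end usage sketch of the converter with a masked LSTM (assumed setup, not taken from the PR; the exact Keras import path depends on whether keras2onnx is driving tf.keras or standalone Keras in your environment):

import numpy as np
import onnxruntime
import keras2onnx
from tensorflow import keras

model = keras.Sequential([
    keras.layers.Masking(mask_value=0., input_shape=(3, 5)),
    keras.layers.LSTM(8),
])
onnx_model = keras2onnx.convert_keras(model, model.name)

x = np.random.uniform(100, 999, size=(2, 3, 5)).astype(np.float32)
x[1, 1:, :] = 0  # second sample is masked after its first timestep

sess = onnxruntime.InferenceSession(onnx_model.SerializeToString())
onnx_out = sess.run(None, {sess.get_inputs()[0].name: x})[0]
np.testing.assert_allclose(model.predict(x), onnx_out, rtol=1e-3, atol=1e-5)
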
2 changes: 1 addition & 1 deletion keras2onnx/ke2onnx/main.py
@@ -79,7 +79,7 @@ def _apply_not_equal(oopb, target_opset, operator):
k2o_logger().warning("On converting a model with opset < 11, " +
"the masking layer result may be incorrect if the model input is in range (0, 1.0).")
equal_input_0 = oopb.add_node('Cast', [operator.inputs[0].full_name],
operator.full_name + '_input_cast', to=6)
operator.full_name + '_input_cast', to=oopb.int32)
equal_out = oopb.add_node('Equal', [equal_input_0, np.array([operator.mask_value], dtype='int32')],
operator.full_name + 'mask')
not_o = oopb.add_node('Not', equal_out,
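
The main.py change replaces the magic number 6 with oopb.int32 in the Cast node's to attribute; 6 is the ONNX TensorProto enum value for INT32, so behaviour is unchanged and only the intent is made explicit. A quick check, assuming the onnx package is installed:

from onnx import TensorProto
assert TensorProto.INT32 == 6  # the literal value the Cast node previously used
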
13 changes: 12 additions & 1 deletion keras2onnx/ke2onnx/simplernn.py
@@ -48,7 +48,13 @@ def convert_keras_simple_rnn(scope, operator, container):
rnn_input_names.append('')

# sequence_lens is not able to be converted from input_length
rnn_input_names.append('')
uses_masking_layer = len(operator.input_masks) == 1
if uses_masking_layer:
# Mask using sequence_lens input
sequence_lengths = scope.get_unique_variable_name(operator.full_name + '_seq_lens')
rnn_input_names.append(sequence_lengths)
else:
rnn_input_names.append('')
# initial_h: none
if len(operator.inputs) == 1:
rnn_input_names.append('')
@@ -77,6 +83,11 @@ def convert_keras_simple_rnn(scope, operator, container):
rnn_output_names.append(rnn_y_name)
rnn_output_names.append(rnn_h_name)
oopb = OnnxOperatorBuilder(container, scope)

if uses_masking_layer:
mask_cast = oopb.apply_cast(operator.input_masks[0].full_name, to=oopb.int32, name=operator.full_name + '_mask_cast')
oopb.add_node_with_output('ReduceSum', mask_cast, sequence_lengths, keepdims=False, axes=[-1], name=operator.full_name + '_mask_sum')

oopb.apply_op_with_output('apply_rnn',
rnn_input_names,
rnn_output_names,
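
Why feeding sequence_lens reproduces Keras masking, as a small numpy illustration (assumed behaviour based on the ONNX RNN specification, not code from this PR): for each batch entry, steps at or beyond its sequence length contribute nothing, so the final hidden state matches what Keras computes when the Masking layer skips those timesteps.

import numpy as np

def simple_rnn_with_seq_lens(x, W, R, b, seq_lens):
    # x: (batch, timesteps, features); W: (features, hidden); R: (hidden, hidden); b: (hidden,)
    batch, timesteps, _ = x.shape
    h = np.zeros((batch, W.shape[1]))
    for t in range(timesteps):
        alive = (t < seq_lens)[:, None]            # samples still within their sequence length
        h_next = np.tanh(x[:, t] @ W + h @ R + b)
        h = np.where(alive, h_next, h)             # masked steps leave the state untouched
    return h
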
30 changes: 30 additions & 0 deletions tests/test_layers.py
@@ -1796,6 +1796,36 @@ def test_masking(self):
expected = model.predict(x)
self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files))

@unittest.skipIf(is_tf2 and is_tf_keras, 'TODO')
def test_masking_bias(self):
for rnn_class in [LSTM, GRU, SimpleRNN]:

timesteps, features = (3, 5)
model = Sequential([
keras.layers.Masking(mask_value=0., input_shape=(timesteps, features)),
rnn_class(8, return_state=False, return_sequences=False, use_bias=True, name='rnn')
])

x = np.random.uniform(100, 999, size=(2, 3, 5)).astype(np.float32)
# Fill one of the entries with all zeros except the first timestep
x[1, 1:, :] = 0

# Test with the default bias
expected = model.predict(x)
onnx_model = keras2onnx.convert_keras(model, model.name)
self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files))

# Set bias values to random floats
rnn_layer = model.get_layer('rnn')
weights = rnn_layer.get_weights()
weights[2] = np.random.uniform(size=weights[2].shape)
rnn_layer.set_weights(weights)

# Test with random bias
expected = model.predict(x)
onnx_model = keras2onnx.convert_keras(model, model.name)
self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files))

@unittest.skipIf(is_tf2 and is_tf_keras, 'TODO')
def test_masking_value(self):
timesteps, features = (3, 5)