Skip to content
This repository was archived by the owner on Oct 13, 2021. It is now read-only.

Commit adbc41c

Browse files
committed
Adding broken test for masked-bidirectional RNNs
1 parent 4a22019 commit adbc41c

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

tests/test_layers.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1826,6 +1826,36 @@ def test_masking_bias(self):
18261826
onnx_model = keras2onnx.convert_keras(model, model.name)
18271827
self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files))
18281828

1829+
@unittest.skipIf(is_tf2 and is_tf_keras, 'TODO')
1830+
def test_masking_bias_bidirectional(self):
1831+
for rnn_class in [LSTM, GRU, SimpleRNN]:
1832+
1833+
timesteps, features = (3, 5)
1834+
model = Sequential([
1835+
keras.layers.Masking(mask_value=0., input_shape=(timesteps, features)),
1836+
Bidirectional(rnn_class(8, return_state=False, return_sequences=False, use_bias=True, name='rnn'))
1837+
])
1838+
1839+
x = np.random.uniform(100, 999, size=(2, 3, 5)).astype(np.float32)
1840+
# Fill one of the entries with all zeros except the first timestep
1841+
x[1, 1:, :] = 0
1842+
1843+
# Test with the default bias
1844+
expected = model.predict(x)
1845+
onnx_model = keras2onnx.convert_keras(model, model.name)
1846+
self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files))
1847+
1848+
# Set bias values to random floats
1849+
rnn_layer = model.get_layer('rnn')
1850+
weights = rnn_layer.get_weights()
1851+
weights[2] = np.random.uniform(size=weights[2].shape)
1852+
rnn_layer.set_weights(weights)
1853+
1854+
# Test with random bias
1855+
expected = model.predict(x)
1856+
onnx_model = keras2onnx.convert_keras(model, model.name)
1857+
self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected, self.model_files))
1858+
18291859
@unittest.skipIf(is_tf2 and is_tf_keras, 'TODO')
18301860
def test_masking_value(self):
18311861
timesteps, features = (3, 5)

0 commit comments

Comments
 (0)