
Fix bidirectional output_seq=False case #85

Merged
merged 8 commits · May 8, 2019
3 changes: 1 addition & 2 deletions keras2onnx/ke2onnx/bidirectional.py
@@ -245,12 +245,12 @@ def convert_bidirectional(scope, operator, container):
         container.add_node('Squeeze', backward_y_name, operator.outputs[1].full_name,
                            name=scope.get_unique_variable_name('Squeeze'), axes=[2])
     else:
+        perm = [1, 0, 2]
         if merge_concat:
             # In this case, only one Keras output with shape (N, 2 * C') should be produced

             # Transpose ONNX LSTM Y_h with shape (D, N, C') into (N, D, C')
             transposed_h_name = scope.get_unique_variable_name(operator.full_name + '_Y_h_transposed')
-            perm = [1, 0, 2] if container.target_opset <= 5 else [2, 0, 1, 3]
             apply_transpose(scope, lstm_h_name, transposed_h_name, container, perm=perm)

             # Maintain backwards opset compatibility for 'Flatten'
@@ -265,7 +265,6 @@ def convert_bidirectional(scope, operator, container):

             # Transpose ONNX LSTM Y_h with shape (D, N, C') into (N, D, C')
             transposed_h_name = scope.get_unique_variable_name(operator.full_name + '_Y_h_transposed')
-            perm = [1, 0, 2] if container.target_opset <= 5 else [2, 0, 1, 3]
             apply_transpose(scope, lstm_h_name, transposed_h_name, container, perm=perm)

             # Split the transposed Y with shape (T, N, D, C') into (T, N, 1, C') and (T, N, 1, C')
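For context on the fix: with return_sequences=False the converter consumes the ONNX LSTM state output Y_h, which is rank 3 with shape (D, N, C') (directions, batch, hidden), so the opset-gated 4-D permutation [2, 0, 1, 3] deleted above never applied to it; the hoisted perm = [1, 0, 2] now serves both merge branches. A minimal numpy sketch of the shape bookkeeping (illustrative values only, not the converter code):

import numpy as np

# ONNX bidirectional LSTM emits Y_h with shape (D, N, C'):
# D = num_directions (2 here), N = batch size, C' = hidden size.
D, N, C = 2, 4, 10
y_h = np.random.rand(D, N, C).astype(np.float32)

# Transpose (D, N, C') -> (N, D, C'), exactly what perm = [1, 0, 2] does.
transposed = np.transpose(y_h, (1, 0, 2))
assert transposed.shape == (N, D, C)

# In the merge_concat branch, the Flatten that follows collapses the last
# two axes, giving the single Keras output of shape (N, 2 * C').
merged = transposed.reshape(N, D * C)
assert merged.shape == (N, 2 * C)

# The removed perm [2, 0, 1, 3] permutes a rank-4 tensor such as the full
# sequence output Y of shape (T, N, D, C'); it is not a valid permutation
# of the rank-3 Y_h, which is why the output_seq=False case needed fixing.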
49 changes: 30 additions & 19 deletions tests/test_layers.py
@@ -532,9 +532,8 @@ def test_batch_normalization(self):
         self._batch_norm_helper(data, 'zeros', 'zeros', False, True, 1)

     def test_simpleRNN(self):
-        from keras.layers import Input, Dense, SimpleRNN
-        inputs1 = Input(shape=(3, 1))
-        cls = SimpleRNN(2, return_state=False, return_sequences=True)
+        inputs1 = keras.Input(shape=(3, 1))
+        cls = keras.layers.SimpleRNN(2, return_state=False, return_sequences=True)
         oname = cls(inputs1)  # , initial_state=t0)
         model = keras.Model(inputs=inputs1, outputs=[oname])
         onnx_model = keras2onnx.convert_keras(model, model.name)
@@ -544,10 +543,10 @@ def test_simpleRNN(self):
         self.assertTrue(self.run_onnx_runtime(onnx_model.graph.name, onnx_model, data, expected))

         # with initial state
-        inputs2 = Input(shape=(1, 2))
-        state = Input(shape=(5,))
-        hidden_1 = SimpleRNN(5, activation='relu', return_sequences=True)(inputs2, initial_state=[state])
-        output = Dense(2, activation='sigmoid')(hidden_1)
+        inputs2 = keras.Input(shape=(1, 2))
+        state = keras.Input(shape=(5,))
+        hidden_1 = keras.layers.SimpleRNN(5, activation='relu', return_sequences=True)(inputs2, initial_state=[state])
+        output = keras.layers.Dense(2, activation='sigmoid')(hidden_1)
         keras_model = keras.Model(inputs=[inputs2, state], outputs=output)
         onnx_model = keras2onnx.convert_keras(keras_model, keras_model.name, debug_mode=True)
@@ -558,11 +557,11 @@ def test_simpleRNN(self):
         self.assertTrue(self.run_onnx_runtime(onnx_model.graph.name, onnx_model, [x, s], expected))

         # with initial state and output state
-        input = Input(shape=(1, 2))
-        state_in = Input(shape=(10,))
-        hidden_1, state_out = SimpleRNN(10, activation='relu', return_sequences=True, return_state=True)(input,
-                                                                                                         initial_state=[state_in])
-        output = Dense(2, activation='linear')(hidden_1)
+        input = keras.Input(shape=(1, 2))
+        state_in = keras.Input(shape=(10,))
+        hidden_1, state_out = keras.layers.SimpleRNN(10, activation='relu', return_sequences=True, return_state=True)(input,
+                                                                                                                      initial_state=[state_in])
+        output = keras.layers.Dense(2, activation='linear')(hidden_1)
         keras_model = keras.Model(inputs=[input, state_in], outputs=[output, state_out])
         onnx_model = keras2onnx.convert_keras(keras_model, keras_model.name)
@@ -625,6 +624,21 @@ def test_LSTM_reshape(self):
         expected = model.predict(data)
         self.assertTrue(self.run_onnx_runtime('tf_lstm', onnx_model, data, expected))

+    def test_Bidirectional(self):
+        input_dim = 10
+        sequence_len = 5
+        model = keras.Sequential()
+        model.add(keras.layers.Bidirectional(keras.layers.LSTM(10, return_sequences=False),
+                                             input_shape=(5, 10)))
+        model.add(keras.layers.Dense(5))
+        model.add(keras.layers.Activation('softmax'))
+        model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
+
+        onnx_model = keras2onnx.convert_keras(model, 'test')
+        data = np.random.rand(input_dim, sequence_len).astype(np.float32).reshape((1, sequence_len, input_dim))
+        expected = model.predict(data)
+        self.assertTrue(self.run_onnx_runtime('bidirectional', onnx_model, data, expected))
+
     def test_separable_convolution(self):
         N, C, H, W = 2, 3, 5, 5
         x = np.random.rand(N, H, W, C).astype(np.float32, copy=False)
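The new test_Bidirectional above exercises exactly the return_sequences=False path fixed in bidirectional.py. For readers trying it outside the test harness, a self-contained sketch of the same round trip, assuming keras, keras2onnx, and onnxruntime are installed (the model name and tolerances are illustrative):

import numpy as np
import keras
import keras2onnx
import onnxruntime

# Bidirectional LSTM with return_sequences=False; the default
# merge_mode='concat' yields a single output of shape (N, 2 * 10).
model = keras.Sequential()
model.add(keras.layers.Bidirectional(keras.layers.LSTM(10, return_sequences=False),
                                     input_shape=(5, 10)))

onnx_model = keras2onnx.convert_keras(model, 'bidi_sketch')
sess = onnxruntime.InferenceSession(onnx_model.SerializeToString())

data = np.random.rand(1, 5, 10).astype(np.float32)
expected = model.predict(data)  # shape (1, 20)
actual = sess.run(None, {sess.get_inputs()[0].name: data})[0]
np.testing.assert_allclose(expected, actual, rtol=1e-3, atol=1e-5)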
@@ -709,21 +723,18 @@ def test_recursive_and_shared_model(self):
         self.assertTrue(self.run_onnx_runtime('recursive_and_shared', onnx_model, x, expected))

     def test_timedistributed(self):
-        from keras import Sequential
-        from keras.layers import TimeDistributed, Dense, Conv2D
-
-        keras_model = Sequential()
-        keras_model.add(TimeDistributed(Dense(8), input_shape=(10, 16)))
+        keras_model = keras.Sequential()
+        keras_model.add(keras.layers.TimeDistributed(keras.layers.Dense(8), input_shape=(10, 16)))
         # keras_model.output_shape == (None, 10, 8)
         onnx_model = keras2onnx.convert_keras(keras_model, keras_model.name, debug_mode=True)
         x = np.random.rand(32, 10, 16).astype(np.float32)
         expected = keras_model.predict(x)
         self.assertTrue(self.run_onnx_runtime(onnx_model.graph.name, onnx_model, x, expected))

-        keras_model = Sequential()
+        keras_model = keras.Sequential()
         N, D, W, H, C = 5, 10, 15, 15, 3
-        keras_model.add(TimeDistributed(Conv2D(64, (3, 3)),
-                                        input_shape=(D, W, H, C)))
+        keras_model.add(keras.layers.TimeDistributed(keras.layers.Conv2D(64, (3, 3)),
+                                                     input_shape=(D, W, H, C)))
         onnx_model = keras2onnx.convert_keras(keras_model, keras_model.name, debug_mode=True)
         x = np.random.rand(N, D, W, H, C).astype(np.float32)
         expected = keras_model.predict(x)