Skip to content
This repository was archived by the owner on Oct 13, 2021. It is now read-only.

Commit 3ee0415

Browse files
authored
Adding support for recurrent_v2 layers for TF 2.0 compatibility ahead of TF 2.0 (#610)
Signed-off-by: Colin Jermain <[email protected]>
1 parent e3ebae8 commit 3ee0415

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

keras2onnx/ke2onnx/bidirectional.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@
66
import collections
77
import numbers
88
from ..common import cvtfunc
9-
from ..proto import keras
9+
from ..proto import keras, is_tf_keras, is_tensorflow_later_than
1010
from . import simplernn, gru, lstm
1111

12+
LSTM_CLASSES = {keras.layers.LSTM}
13+
GRU_CLASSES = {keras.layers.GRU}
14+
1215

1316
def _calculate_keras_bidirectional_output_shapes(operator):
1417
op = operator.raw_operator
@@ -27,9 +30,15 @@ def convert_bidirectional(scope, operator, container):
2730
op_type = type(operator.raw_operator.forward_layer)
2831
bidirectional = True
2932

30-
if op_type == keras.layers.LSTM:
33+
if is_tf_keras and is_tensorflow_later_than("1.14.0"):
34+
# Add the TF v2 compatability layers (available after TF 1.14)
35+
from tensorflow.python.keras.layers import recurrent_v2
36+
LSTM_CLASSES.add(recurrent_v2.LSTM)
37+
GRU_CLASSES.add(recurrent_v2.GRU)
38+
39+
if op_type in LSTM_CLASSES:
3140
lstm.convert_keras_lstm(scope, operator, container, bidirectional)
32-
elif op_type == keras.layers.GRU:
41+
elif op_type in GRU_CLASSES:
3342
gru.convert_keras_gru(scope, operator, container, bidirectional)
3443
elif op_type == keras.layers.SimpleRNN:
3544
simplernn.convert_keras_simple_rnn(scope, operator, container, bidirectional)

tests/test_layers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@
7171

7272
RNN_CLASSES = [SimpleRNN, GRU, LSTM]
7373

74+
if is_tf_keras and is_tensorflow_later_than("1.14.0"):
75+
# Add the TF v2 compatability layers (available after TF 1.14)
76+
from tensorflow.python.keras.layers import recurrent_v2
77+
RNN_CLASSES.extend([recurrent_v2.GRU, recurrent_v2.LSTM])
78+
7479

7580
def _asarray(*a):
7681
return np.array([a], dtype='f')

0 commit comments

Comments
 (0)