Skip to content
This repository was archived by the owner on Oct 13, 2021. It is now read-only.

Support for recurrent_v2 layers for TF 2.0 #610

Merged
merged 1 commit into from
Oct 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions keras2onnx/ke2onnx/bidirectional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import collections
import numbers
from ..common import cvtfunc
from ..proto import keras
from ..proto import keras, is_tf_keras, is_tensorflow_later_than
from . import simplernn, gru, lstm

LSTM_CLASSES = {keras.layers.LSTM}
GRU_CLASSES = {keras.layers.GRU}


def _calculate_keras_bidirectional_output_shapes(operator):
op = operator.raw_operator
Expand All @@ -27,9 +30,15 @@ def convert_bidirectional(scope, operator, container):
op_type = type(operator.raw_operator.forward_layer)
bidirectional = True

if op_type == keras.layers.LSTM:
if is_tf_keras and is_tensorflow_later_than("1.14.0"):
        # Add the TF v2 compatibility layers (available after TF 1.14)
from tensorflow.python.keras.layers import recurrent_v2
LSTM_CLASSES.add(recurrent_v2.LSTM)
GRU_CLASSES.add(recurrent_v2.GRU)

if op_type in LSTM_CLASSES:
lstm.convert_keras_lstm(scope, operator, container, bidirectional)
elif op_type == keras.layers.GRU:
elif op_type in GRU_CLASSES:
gru.convert_keras_gru(scope, operator, container, bidirectional)
elif op_type == keras.layers.SimpleRNN:
simplernn.convert_keras_simple_rnn(scope, operator, container, bidirectional)
Expand Down
5 changes: 5 additions & 0 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@

RNN_CLASSES = [SimpleRNN, GRU, LSTM]

if is_tf_keras and is_tensorflow_later_than("1.14.0"):
    # Add the TF v2 compatibility layers (available after TF 1.14)
from tensorflow.python.keras.layers import recurrent_v2
RNN_CLASSES.extend([recurrent_v2.GRU, recurrent_v2.LSTM])


def _asarray(*a):
return np.array([a], dtype='f')
Expand Down