This repository was archived by the owner on Oct 13, 2021. It is now read-only.

Commit ee2945c

Fix LSTM layer conversion in tf 2.x (#412)
* Fix the sequential-model and tf 2.2 issues.
* More adjustments.
* Fix the import issue.
* Remove the keras/tf2.0 combination.
* Experimental changes.
* Revert the experimental change.
* More reverts.
1 parent bab7030 commit ee2945c

File tree: 6 files changed, +82 / −40 lines changed
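
For context, a minimal sketch (not part of the commit) of the scenario this fix targets: converting a tf.keras Sequential LSTM model under TF 2.x and checking the result with ONNX Runtime. Layer sizes, input shapes, and tolerances here are illustrative.

import numpy as np
import tensorflow as tf
import keras2onnx
import onnxruntime

# In TF 2.x, tf.keras.layers.LSTM resolves to the recurrent_v2.LSTM class that
# this commit registers with the converter (see keras2onnx/ke2onnx/main.py below).
model = tf.keras.Sequential([
    tf.keras.layers.LSTM(8, input_shape=(4, 3)),
    tf.keras.layers.Dense(2),
])
onnx_model = keras2onnx.convert_keras(model, model.name)

# Compare Keras and ONNX Runtime outputs on a random input.
x = np.random.rand(1, 4, 3).astype(np.float32)
sess = onnxruntime.InferenceSession(onnx_model.SerializeToString())
onnx_out = sess.run(None, {sess.get_inputs()[0].name: x})[0]
np.testing.assert_allclose(model.predict(x), onnx_out, rtol=1e-2, atol=1e-4)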

.azure-pipelines/win32-conda-CI.yml

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ jobs:
       python.version: '3.6'
       ONNX_PATH: onnx==1.5.0
       KERAS: keras==2.2.5
-      TENSORFLOW_PATH: tensorflow==1.14.0
+      TENSORFLOW_PATH: tensorflow==1.15.0
       INSTALL_ORT: pip install onnxruntime==1.1.1
 
     Python37-tf200:

applications/nightly_build/test_transformers.py

Lines changed: 32 additions & 7 deletions
@@ -9,6 +9,7 @@
 import keras2onnx
 import json
 from os.path import dirname, abspath
+
 sys.path.insert(0, os.path.join(dirname(abspath(__file__)), '../../tests/'))
 from test_utils import run_onnx_runtime
 from keras2onnx.proto import is_tensorflow_older_than
@@ -18,7 +19,7 @@
 enable_transformer_test = True
 
 
-@unittest.skipIf(is_tensorflow_older_than('2.1.0') or not enable_transformer_test,
+@unittest.skipIf(not enable_transformer_test,
                  "Need enable transformer test before Transformers conversion.")
 class TestTransformers(unittest.TestCase):
 
@@ -38,6 +39,18 @@ def _prepare_inputs(self, tokenizer):
         inputs_onnx = {k_: v_.numpy() for k_, v_ in inputs.items()}
         return text, inputs, inputs_onnx
 
+    def test_3layer_gpt2(self):
+        from transformers import GPT2Config, TFGPT2Model, BertTokenizer
+        keras2onnx.proto.keras.backend.set_learning_phase(0)
+        config = GPT2Config(n_layer=3)
+        model = TFGPT2Model(config)
+        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+        text, inputs, inputs_onnx = self._prepare_inputs(tokenizer)
+        inputs = tokenizer.encode_plus(text, add_special_tokens=True, return_tensors='tf')
+        predictions = model.predict(inputs)
+        onnx_model = keras2onnx.convert_keras(model, model.name)
+        self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs_onnx, predictions, self.model_files))
+
     def test_TFBertModel(self):
         from transformers import BertTokenizer, TFBertModel
         pretrained_weights = 'bert-base-uncased'
@@ -56,7 +69,9 @@ def test_TFBertForPreTraining(self):
         model = TFBertForPreTraining.from_pretrained(pretrained_weights)
         predictions = model.predict(inputs)
         onnx_model = keras2onnx.convert_keras(model, model.name)
-        self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs_onnx, predictions, self.model_files, rtol=1.e-2, atol=1.e-4))
+        self.assertTrue(
+            run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs_onnx, predictions, self.model_files, rtol=1.e-2,
+                             atol=1.e-4))
 
     def test_TFBertForMaskedLM(self):
         from transformers import BertTokenizer, TFBertForMaskedLM
@@ -66,7 +81,9 @@ def test_TFBertForMaskedLM(self):
         model = TFBertForMaskedLM.from_pretrained(pretrained_weights)
         predictions = model.predict(inputs)
         onnx_model = keras2onnx.convert_keras(model, model.name)
-        self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs_onnx, predictions, self.model_files, rtol=1.e-2, atol=1.e-4))
+        self.assertTrue(
+            run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs_onnx, predictions, self.model_files, rtol=1.e-2,
+                             atol=1.e-4))
 
     def test_TFBertForNextSentencePrediction(self):
         from transformers import BertTokenizer, TFBertForNextSentencePrediction
@@ -146,7 +163,9 @@ def test_TFXLMModel(self):
         model = TFXLMModel.from_pretrained(pretrained_weights)
         predictions = model.predict(inputs)
         onnx_model = keras2onnx.convert_keras(model, model.name)
-        self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs_onnx, predictions, self.model_files, rtol=1.e-2, atol=1.e-4))
+        self.assertTrue(
+            run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs_onnx, predictions, self.model_files, rtol=1.e-2,
+                             atol=1.e-4))
 
     def test_TFXLMWithLMHeadModel(self):
         from transformers import XLMTokenizer, TFXLMWithLMHeadModel
@@ -156,7 +175,9 @@ def test_TFXLMWithLMHeadModel(self):
         model = TFXLMWithLMHeadModel.from_pretrained(pretrained_weights)
         predictions = model.predict(inputs)
         onnx_model = keras2onnx.convert_keras(model, model.name)
-        self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs_onnx, predictions, self.model_files, rtol=1.e-2, atol=1.e-4))
+        self.assertTrue(
+            run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs_onnx, predictions, self.model_files, rtol=1.e-2,
+                             atol=1.e-4))
 
     def test_TFXLMForSequenceClassification(self):
         from transformers import XLMTokenizer, TFXLMForSequenceClassification
@@ -196,7 +217,9 @@ def test_TFDistilBertForMaskedLM(self):
         model = TFDistilBertForMaskedLM.from_pretrained(pretrained_weights)
         predictions = model.predict(inputs)
         onnx_model = keras2onnx.convert_keras(model, model.name)
-        self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs_onnx, predictions, self.model_files, rtol=1.e-2, atol=1.e-4))
+        self.assertTrue(
+            run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs_onnx, predictions, self.model_files, rtol=1.e-2,
+                             atol=1.e-4))
 
     def test_TFDistilBertForSequenceClassification(self):
         from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification
@@ -246,7 +269,9 @@ def test_TFRobertaForMaskedLM(self):
         model = TFRobertaForMaskedLM.from_pretrained(pretrained_weights)
         predictions = model.predict(inputs)
         onnx_model = keras2onnx.convert_keras(model, model.name)
-        self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs_onnx, predictions, self.model_files, rtol=1.e-2, atol=1.e-4))
+        self.assertTrue(
+            run_onnx_runtime(onnx_model.graph.name, onnx_model, inputs_onnx, predictions, self.model_files, rtol=1.e-2,
+                             atol=1.e-4))
 
     def test_TFRobertaForSequenceClassification(self):
         from transformers import RobertaTokenizer, TFRobertaForSequenceClassification

keras2onnx/_graph_cvt.py

Lines changed: 17 additions & 9 deletions
@@ -494,7 +494,7 @@ def _save_placeholder(node_name, dtype):
       # Get dtype and data for non-variable Placeholders (ex. values for 1.X
       # Const ops that are loaded as Placeholders in 2.0)
       _save_placeholder(node.name, node.attr["dtype"])
-    elif node.op in ["ReadVariableOp", "ResourceGather", "AssignSubVariableOp"]:
+    elif node.op in ["ReadVariableOp", "ResourceGather", "ResourceGatherNd", "AssignSubVariableOp"]:
       # Get dtype and data for Placeholder ops associated with ReadVariableOp
       # and ResourceGather ops. There can be an Identity in between the
       # resource op and Placeholder. Store the dtype for the Identity ops.
@@ -532,12 +532,12 @@ def _save_placeholder(node_name, dtype):
       _populate_identity_op(output_node, input_node)
     # Convert ResourceGather to Gather ops with a Const axis feeding into it.
     elif input_node.op == "AssignSubVariableOp":
-        output_node.op = "Sub"
-        output_node.name = input_node.name
-        output_node.input.extend(input_node.input)
-        output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
-        if "_class" in input_node.attr:
-          output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
+      output_node.op = "Sub"
+      output_node.name = input_node.name
+      output_node.input.extend(input_node.input)
+      output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
+      if "_class" in input_node.attr:
+        output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
     elif input_node.op == "ResourceGather":
       if input_node.attr["batch_dims"].i != 0:
         raise ValueError("batch_dims != 0 is not supported by freeze_graph.")
@@ -557,6 +557,15 @@ def _save_placeholder(node_name, dtype):
         output_node.attr["Taxis"].CopyFrom(axis_dtype)
       if "_class" in input_node.attr:
         output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
+    elif input_node.op == "ResourceGatherNd":
+      output_node.op = "GatherNd"
+      output_node.name = input_node.name
+      output_node.input.extend(
+          [input_node.input[0], input_node.input[1]])
+      output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"])
+      output_node.attr["Tindices"].CopyFrom(input_node.attr["Tindices"])
+      if "_class" in input_node.attr:
+        output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
     # Update the function names and argument types for the conditional ops.
     elif input_node.op in _CONDITIONAL_OPS:
       _populate_if_op(output_node, input_node, function_data)
@@ -625,5 +634,4 @@ def _save_placeholder(node_name, dtype):
       output_node.input[idx] = input_name
 
   output_graph_def.versions.CopyFrom(graph_def.versions)
-  return _construct_concrete_function(func, output_graph_def,
-                                      converted_input_indices)
+  return output_graph_def, converted_input_indices
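
ResourceGatherNd is the resource-variable form of GatherNd, so once the backing variable has been frozen to a constant the op reduces to a plain tf.gather_nd, which is what the rewrite above relies on. A small standalone illustration (values arbitrary):

import tensorflow as tf

params = tf.constant([[1., 2.], [3., 4.]])
indices = tf.constant([[1, 0], [0, 1]])
# gather_nd selects params[1][0] and params[0][1].
print(tf.gather_nd(params, indices).numpy())  # [3. 2.]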

keras2onnx/_parse_tf.py

Lines changed: 3 additions & 22 deletions
@@ -204,37 +204,18 @@ def build_layer_outputs(model, graph, outputs):
     return output_dict
 
 
-TF_GRAPH_OPTIMIZATION = False
-
-
 def extract_outputs_from_subclassing_model(model, output_dict, output_names):
-    from tensorflow.core.protobuf import config_pb2
     from tensorflow.python.keras.saving import saving_utils as _saving_utils
-    from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
+    from tensorflow.python.util import object_identity
     from ._graph_cvt import convert_variables_to_constants_v2 as _convert_to_constants
 
     function = _saving_utils.trace_model_call(model)
     concrete_func = function.get_concrete_function()
     output_names.extend([ts_.name for ts_ in concrete_func.outputs])
     output_dict.update(build_layer_outputs(model, concrete_func.graph, concrete_func.outputs))
-    frozen_func = _convert_to_constants(
+    graph_def, converted_input_indices = _convert_to_constants(
         concrete_func, lower_control_flow=True)
-    graph_def = frozen_func.graph.as_graph_def()
-    if TF_GRAPH_OPTIMIZATION:
-        input_tensors = [
-            tensor for tensor in frozen_func.inputs
-            if tensor.dtype != tf.dtypes.resource
-        ]
-        output_tensors = frozen_func.outputs
-        config = config_pb2.ConfigProto()
-        rewrite_options = config.graph_options.rewrite_options
-        rewrite_options.constant_folding = rewrite_options.ON
-        graph_def = _run_graph_optimizations(
-            graph_def,
-            input_tensors,
-            output_tensors,
-            config=config,
-            graph=frozen_func.graph)
+
     with tf.Graph().as_default() as tf_graph:
         tf.import_graph_def(graph_def, name='')

keras2onnx/ke2onnx/main.py

Lines changed: 2 additions & 0 deletions
@@ -216,6 +216,8 @@ def convert_keras_training_only_layer(scope, operator, container):
 
 if is_tf_keras and is_tf2:
     keras_layer_to_operator.update({
+        _layer.recurrent_v2.GRU: convert_keras_gru,
+        _layer.recurrent_v2.LSTM: convert_keras_lstm,
         _layer.normalization_v2.BatchNormalization: convert_keras_batch_normalization,
     })
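
On TF 2.x under tf.keras, the LSTM and GRU layer classes come from recurrent_v2, not the classic keras.layers recurrent module, so without these two entries the layer-to-converter lookup presumably missed them; this registration is the core of the fix named in the commit title. A quick check of which class a layer resolves to (a sketch):

import tensorflow as tf
from tensorflow.python.keras.layers import recurrent_v2

layer = tf.keras.layers.LSTM(4)
print(isinstance(layer, recurrent_v2.LSTM))  # True on TF 2.x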

keras2onnx/parser.py

Lines changed: 27 additions & 1 deletion
@@ -605,11 +605,15 @@ def _parse_graph_core(graph, keras_node_dict, topology, top_scope, output_names):
 
 def _sorted_inputs(nodelist, outputs, inputs_set):
     inputs = []
-    node_set = set(nodelist)
+    node_set = frozenset(nodelist)
+    visited = set()
 
     def travel(node):
         for in_ts_ in node.inputs:
             op_node = in_ts_.op
+            if op_node in visited:
+                continue
+            visited.add(op_node)
             if (op_node in inputs_set) and (op_node not in inputs):
                 inputs.append(op_node)
             elif op_node in node_set:
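
The new visited set memoizes the traversal: ops in a TF graph commonly feed several consumers, and without memoization the recursive travel can re-expand shared subgraphs over and over (exponentially in the worst case), while with it each node is expanded once. A generic illustration of the pattern with a hypothetical Node type, not keras2onnx code:

class Node:
    """Toy stand-in for a TF op with upstream inputs."""
    def __init__(self, name, inputs=()):
        self.name, self.inputs = name, list(inputs)

def travel(node, visited=None, order=None):
    """Depth-first walk that expands each upstream node exactly once."""
    visited = set() if visited is None else visited
    order = [] if order is None else order
    for parent in node.inputs:
        if parent in visited:   # memoized: skip nodes already expanded
            continue
        visited.add(parent)
        order.append(parent.name)
        travel(parent, visited, order)
    return order

shared = Node('shared')
root = Node('root', [Node('a', [shared]), Node('b', [shared])])
print(travel(root))  # ['a', 'shared', 'b'] -- 'shared' expanded once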
@@ -715,6 +719,28 @@ def _parse_graph_core_v2(graph, keras_node_dict, topology, top_scope, output_names):
     return topology
 
 
+def parse_graph_modeless(topo, graph, target_opset, input_names, output_names, keras_node_dict):
+    top_level = topo.declare_scope('__root')
+    input_tensors = [graph.get_tensor_by_name(n_) for n_ in input_names]
+    output_tensors = [graph.get_tensor_by_name(n_) for n_ in output_names]
+
+    for ts_i_ in input_tensors:
+        var_type = _adjust_input_batch_size(infer_variable_type(ts_i_, target_opset))
+        str_value = ts_i_.name
+        top_level.get_local_variable_or_declare_one(str_value, var_type)
+        topo.raw_model.add_input_name(str_value)
+
+    for ts_o_ in output_tensors:
+        var_type = _adjust_input_batch_size(infer_variable_type(ts_o_, target_opset))
+        str_value = ts_o_.name
+        top_level.get_local_variable_or_declare_one(str_value, var_type)
+        topo.raw_model.add_output_name(str_value)
+
+    return _parse_graph_core_v2(
+        graph, keras_node_dict, topo, top_level, output_names
+    )
+
+
 def parse_graph(topo, graph, target_opset, output_names, keras_node_dict):
     # type: (Topology, tf.Graph, int, [], []) -> Topology
     """
