Skip to content
This repository was archived by the owner on Oct 13, 2021. It is now read-only.

Support tensorflow 2.2 #484

Merged
merged 9 commits into from
May 12, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .azure-pipelines/linux-CI-keras-applications-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ jobs:
ONNX_PATH: onnx==1.6.0
INSTALL_KERAS:
UNINSTALL_KERAS: pip uninstall keras -y
INSTALL_TENSORFLOW: pip install tensorflow==2.1.0
INSTALL_TENSORFLOW: pip install tensorflow==2.2.0
INSTALL_ORT: pip install -i https://test.pypi.org/simple/ ort-nightly
INSTALL_KERAS_RESNET: pip install keras-resnet
INSTALL_TRANSFORMERS: pip install transformers
Expand Down
2 changes: 1 addition & 1 deletion .azure-pipelines/win32-CI-keras-applications-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ jobs:
ONNX_PATH: onnx==1.6.0
INSTALL_KERAS:
UNINSTALL_KERAS: pip uninstall keras -y
INSTALL_TENSORFLOW: pip install tensorflow==2.1.0
INSTALL_TENSORFLOW: pip install tensorflow==2.2.0
INSTALL_ORT: pip install onnxruntime==1.1.1
INSTALL_KERAS_RESNET: pip install keras-resnet
INSTALL_TRANSFORMERS: pip install transformers
Expand Down
1 change: 0 additions & 1 deletion applications/nightly_build/test_keras_applications_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def test_InceptionV3(self):
keras.backend.set_learning_phase(0)
InceptionV3 = keras.applications.inception_v3.InceptionV3
model = InceptionV3(include_top=True)
model.save('inception.h5')
res = run_image(model, self.model_files, img_path, target_size=299, tf_v2=True)
self.assertTrue(*res)

Expand Down
28 changes: 22 additions & 6 deletions keras2onnx/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,20 +482,28 @@ def _get_output_nodes(node_list, node):
return [n_ for n_ in node_list if n_ not in nodes_has_children] # need to keep the order.


def _filter_out_input(node_name):
# tf.keras BN layer sometimes create a placeholder node 'scale' in tf2.x.
# Given bn layer will be converted in a whole layer, it's fine to just filter this node out.
filter_out = re.match(r"batch_normalization_\d+\/scale$", node_name)
filter_out = filter_out or re.match(r"batch_normalization_\d+\/cond/input", node_name) # since tf 2.2
return filter_out


def _advance_by_input(cur_node, layer_nodes, subgraph, inputs, graph_inputs, q_overall):
for input_ in cur_node.inputs:
predecessor = input_.op
if is_placeholder_node(predecessor):
# tf.keras BN layer sometimes create a placeholder node 'scale' in tf2.x.
# Given bn layer will be converted in a whole layer, it's fine to just filter this node out.
if not re.match(r"batch_normalization_\d+\/scale$", predecessor.name):
if not _filter_out_input(predecessor.name):
inputs.add(predecessor)
graph_inputs.add(predecessor)
continue
if predecessor in layer_nodes or len(layer_nodes) == 0:
subgraph.append(predecessor)
else:
inputs.add(predecessor)
q_overall.put_nowait(predecessor)
if not _filter_out_input(predecessor.name):
inputs.add(predecessor)
q_overall.put_nowait(predecessor)


def _visit_nodelist(activated_keras_nodes, input_nodes, layer_key,
Expand Down Expand Up @@ -735,7 +743,15 @@ def parse_graph_modeless(topo, graph, target_opset, input_names, output_names, k

for ts_i_ in input_tensors:
var_type = _adjust_input_batch_size(infer_variable_type(ts_i_, target_opset))
str_value = ts_i_.name
if ts_i_.name.endswith(':0'):
str_value = ts_i_.name[:-2]
op = top_level.declare_local_operator(TYPES.Identity)
var0 = top_level.get_local_variable_or_declare_one(str_value, var_type)
var1 = top_level.get_local_variable_or_declare_one(ts_i_.name, var_type)
op.add_input(var0)
op.add_output(var1)
else:
str_value = ts_i_.name
top_level.get_local_variable_or_declare_one(str_value, var_type)
topo.raw_model.add_input_name(str_value)

Expand Down
4 changes: 3 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ def run_onnx_runtime(case_name, onnx_model, data, expected, model_files, rtol=1.
if expected is None:
return

if not isinstance(expected, list):
if isinstance(expected, tuple):
expected = list(expected)
elif not isinstance(expected, list):
expected = [expected]

res = all(np.allclose(expected[n_], actual[n_], rtol=rtol, atol=atol) for n_ in range(len(expected)))
Expand Down