Skip to content
This repository was archived by the owner on Oct 13, 2021. It is now read-only.

Commit 4abc0c9

Browse files
authored
[Opset 10] Updating ThresholdedRelu and onnxconverter-common package (#96)
- Updated the onnxconverter-common version requirement to 1.5.0. ThresholdedRelu graduated from an experimental op to a full op in Opset 10; this adds support for that change and includes a test to ensure backwards compatibility.
1 parent 20e461e commit 4abc0c9

File tree

3 files changed

+15
-10
lines changed

3 files changed

+15
-10
lines changed

keras2onnx/ke2onnx/adv_activation.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
###############################################################################
66
from ..proto import keras
77
from distutils.version import StrictVersion
8-
from ..common.onnx_ops import apply_elu, apply_leaky_relu, apply_prelu
8+
from ..common.onnx_ops import apply_elu, apply_leaky_relu, apply_prelu, apply_thresholded_relu
99

1010

1111
activations = keras.layers.advanced_activations
@@ -25,14 +25,15 @@ def convert_keras_advanced_activation(scope, operator, container):
2525
weights = op.get_weights()[0]
2626
apply_prelu(scope, operator.input_full_names[0], operator.output_full_names[0], container,
2727
operator_name=operator.full_name, slope=weights)
28+
elif isinstance(op, activations.ThresholdedReLU):
29+
alpha = op.get_config()['theta']
30+
apply_thresholded_relu(scope, operator.input_full_names[0], operator.output_full_names[0], container,
31+
operator_name=operator.full_name, alpha=[alpha])
2832
else:
2933
attrs = {'name': operator.full_name}
3034
ver_opset = 6
3135
input_tensor_names = [operator.input_full_names[0]]
32-
if isinstance(op, activations.ThresholdedReLU):
33-
op_type = 'ThresholdedRelu'
34-
attrs['alpha'] = op.get_config()['theta']
35-
elif StrictVersion(keras.__version__) >= StrictVersion('2.1.3') and \
36+
if StrictVersion(keras.__version__) >= StrictVersion('2.1.3') and \
3637
isinstance(op, activations.Softmax):
3738
op_type = 'Softmax'
3839
attrs['axis'] = op.get_config()['axis']
@@ -44,4 +45,4 @@ def convert_keras_advanced_activation(scope, operator, container):
4445
else:
4546
raise RuntimeError('Unsupported advanced layer found %s' % type(op))
4647

47-
container.add_node(op_type, input_tensor_names, operator.output_full_names, op_version=ver_opset, **attrs)
48+
container.add_node(op_type, input_tensor_names, operator.output_full_names, op_version=ver_opset, **attrs)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,4 +3,4 @@ protobuf
33
keras
44
requests
55
onnx
6-
onnxconverter-common>=1.4.0
6+
onnxconverter-common>=1.5.0

tests/test_layers.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
import unittest
1010
import keras2onnx
1111
import numpy as np
12-
from keras2onnx.proto import keras, is_tf_keras
12+
from keras2onnx.proto import keras, is_tf_keras, get_opset_number_from_onnx
1313
from distutils.version import StrictVersion
1414

1515

@@ -382,7 +382,9 @@ def test_pooling_3d(self):
382382
def test_pooling_global(self):
383383
self._pooling_test_helper(keras.layers.GlobalAveragePooling2D, (4, 6, 2))
384384

385-
def activationlayer_helper(self, layer, data_for_advanced_layer=None):
385+
def activationlayer_helper(self, layer, data_for_advanced_layer=None, op_version=None):
386+
if op_version is None:
387+
op_version = get_opset_number_from_onnx()
386388
if data_for_advanced_layer is None:
387389
data = self.asarray(-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5)
388390
layer = keras.layers.Activation(layer, input_shape=(data.size,))
@@ -391,7 +393,7 @@ def activationlayer_helper(self, layer, data_for_advanced_layer=None):
391393

392394
model = keras.Sequential()
393395
model.add(layer)
394-
onnx_model = keras2onnx.convert_keras(model, model.name)
396+
onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=op_version)
395397

396398
expected = model.predict(data)
397399
self.assertTrue(self.run_onnx_runtime(onnx_model.graph.name, onnx_model, data, expected))
@@ -444,6 +446,8 @@ def test_LeakyRelu(self):
444446
def test_ThresholdedRelu(self):
445447
data = self.asarray(-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5)
446448
layer = keras.layers.advanced_activations.ThresholdedReLU(theta=1.0, input_shape=(data.size,))
449+
self.activationlayer_helper(layer, data, op_version=8)
450+
layer = keras.layers.advanced_activations.ThresholdedReLU(theta=1.0, input_shape=(data.size,))
447451
self.activationlayer_helper(layer, data)
448452

449453
def test_ELU(self):

0 commit comments

Comments (0)