This repository was archived by the owner on Oct 13, 2021. It is now read-only.

Commit babb949

Support tf.nn.leaky_relu and fix advanced_activations (#514)

1 parent 85e8132 · commit babb949

File tree: 3 files changed, +31 / -9 lines
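As a usage-level sketch of what the commit enables (illustrative only, not part of the commit; the layer sizes mirror the updated test below): a tf.keras model that passes tf.nn.leaky_relu as an ordinary activation argument should now convert instead of hitting the unsupported-activation fallback.

    # Hedged sketch; assumes tf.keras and the keras2onnx package are installed.
    import tensorflow as tf
    import keras2onnx

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation=tf.nn.leaky_relu, input_shape=[10]),
        tf.keras.layers.Dense(1),
    ])
    # The converter now maps tf.nn.leaky_relu to an ONNX LeakyRelu node.
    onnx_model = keras2onnx.convert_keras(model, model.name)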

keras2onnx/ke2onnx/activation.py (23 additions, 7 deletions)

@@ -6,8 +6,8 @@
 import numpy as np
 import tensorflow as tf
 from ..proto import keras, is_tf_keras
-from ..common.onnx_ops import apply_elu, apply_hard_sigmoid, apply_relu, apply_relu_6, apply_sigmoid, apply_tanh, \
-    apply_softmax, apply_identity, apply_selu, apply_mul
+from ..common.onnx_ops import apply_elu, apply_hard_sigmoid, apply_leaky_relu, apply_relu, apply_relu_6, \
+    apply_tanh, apply_softmax, apply_identity, apply_selu, apply_mul, apply_prelu, apply_sigmoid
 from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
 
 activation_get = keras.activations.get
@@ -21,6 +21,11 @@
 if not relu6 and hasattr(keras.applications.mobilenet, 'relu6'):
     relu6 = keras.applications.mobilenet.relu6
 
+
+def apply_leaky_relu_keras(scope, input_name, output_name, container, operator_name=None, alpha=0.2):
+    apply_leaky_relu(scope, input_name, output_name, container, operator_name, alpha)
+
+
 activation_map = {activation_get('sigmoid'): apply_sigmoid,
                   activation_get('softmax'): apply_softmax,
                   activation_get('linear'): apply_identity,
@@ -29,6 +34,7 @@
                   activation_get('selu'): apply_selu,
                   activation_get('tanh'): apply_tanh,
                   activation_get('hard_sigmoid'): apply_hard_sigmoid,
+                  tf.nn.leaky_relu: apply_leaky_relu_keras,
                   tf.nn.sigmoid: apply_sigmoid,
                   tf.nn.softmax: apply_softmax,
                   tf.nn.relu: apply_relu,
@@ -40,6 +46,7 @@
 if hasattr(tf.compat, 'v1'):
     activation_map.update({tf.compat.v1.nn.sigmoid: apply_sigmoid})
     activation_map.update({tf.compat.v1.nn.softmax: apply_softmax})
+    activation_map.update({tf.compat.v1.nn.leaky_relu: apply_leaky_relu_keras})
     activation_map.update({tf.compat.v1.nn.relu: apply_relu})
     activation_map.update({tf.compat.v1.nn.relu6: apply_relu_6})
     activation_map.update({tf.compat.v1.nn.elu: apply_elu})
@@ -51,29 +58,38 @@ def convert_keras_activation(scope, operator, container):
     input_name = operator.input_full_names[0]
     output_name = operator.output_full_names[0]
     activation = operator.raw_operator.activation
+    activation_type = type(activation)
     if activation in [activation_get('sigmoid'), keras.activations.sigmoid]:
         apply_sigmoid(scope, input_name, output_name, container)
     elif activation in [activation_get('tanh'), keras.activations.tanh]:
         apply_tanh(scope, input_name, output_name, container)
-    elif activation in [activation_get('relu'), keras.activations.relu]:
+    elif activation in [activation_get('relu'), keras.activations.relu] or \
+            (hasattr(keras.layers.advanced_activations, 'ReLU') and
+             activation_type == keras.layers.advanced_activations.ReLU):
         apply_relu(scope, input_name, output_name, container)
-    elif activation in [activation_get('softmax'), keras.activations.softmax]:
+    elif activation in [activation_get('softmax'), keras.activations.softmax] or \
+            activation_type == keras.layers.advanced_activations.Softmax:
         apply_softmax(scope, input_name, output_name, container, axis=-1)
-    elif activation in [activation_get('elu'), keras.activations.elu]:
+    elif activation in [activation_get('elu'), keras.activations.elu] or \
+            activation_type == keras.layers.advanced_activations.ELU:
         apply_elu(scope, input_name, output_name, container, alpha=1.0)
     elif activation in [activation_get('hard_sigmoid'), keras.activations.hard_sigmoid]:
         apply_hard_sigmoid(scope, input_name, output_name, container, alpha=0.2, beta=0.5)
     elif activation in [activation_get('linear'), keras.activations.linear]:
         apply_identity(scope, input_name, output_name, container)
     elif activation in [activation_get('selu'), keras.activations.selu]:
         apply_selu(scope, input_name, output_name, container, alpha=1.673263, gamma=1.050701)
-    elif activation in [relu6] or activation.__name__ == 'relu6':
+    elif activation_type == keras.layers.advanced_activations.LeakyReLU:
+        apply_leaky_relu(scope, input_name, output_name, container, alpha=activation.alpha.item(0))
+    elif activation_type == keras.layers.advanced_activations.PReLU:
+        apply_prelu(scope, input_name, output_name, container, slope=operator.raw_operator.get_weights()[0])
+    elif activation in [relu6] or (hasattr(activation, '__name__') and activation.__name__ == 'relu6'):
         # relu6(x) = min(relu(x), 6)
         np_type = TENSOR_TYPE_TO_NP_TYPE[operator.inputs[0].type.to_onnx_type().tensor_type.elem_type]
         zero_value = np.zeros(shape=(1,), dtype=np_type)
         apply_relu_6(scope, input_name, output_name, container,
                      zero_value=zero_value)
-    elif activation.__name__ in ['swish']:
+    elif hasattr(activation, '__name__') and activation.__name__ == 'swish':
         apply_sigmoid(scope, input_name, output_name + '_sig', container)
         apply_mul(scope, [input_name, output_name + '_sig'], output_name, container)
     else:
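The new branches dispatch on type(activation) because the advanced-activation classes (ReLU, Softmax, ELU, LeakyReLU, PReLU) are layer instances rather than functions, so the earlier activation.__name__ checks could raise AttributeError. For LeakyReLU the slope is read from the layer (activation.alpha.item(0)) and passed as the alpha attribute of the ONNX LeakyRelu op; for PReLU the learned slope weights are passed through, since ONNX PRelu takes the slope as a tensor input rather than an attribute. A minimal sketch of the node the LeakyReLU path corresponds to, assuming the standard onnx.helper API (not repo code):

    from onnx import helper

    alpha = 0.2  # illustrative; in the converter this comes from activation.alpha.item(0)
    leaky_node = helper.make_node('LeakyRelu', inputs=['X'], outputs=['Y'], alpha=alpha)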

keras2onnx/ke2onnx/conv.py (1 addition, 1 deletion)

@@ -197,7 +197,7 @@ def convert_keras_conv_core(scope, operator, container, is_transpose, n_dims, in
 
     # The construction of convolution is done. Now, we create an activation operator to apply the activation specified
     # in this Keras layer.
-    if op.activation.__name__ == 'swish':
+    if hasattr(op.activation, '__name__') and op.activation.__name__ == 'swish':
         apply_sigmoid(scope, transpose_output_name, transpose_output_name + '_sig', container)
         apply_mul(scope, [transpose_output_name, transpose_output_name + '_sig'], operator.outputs[0].full_name,
                   container)
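The added hasattr guard matters because op.activation is not always a plain function: when an advanced-activation layer instance is used, there is no __name__ attribute and the old check raised AttributeError before any fallback could run. A standalone illustration of the distinction (assumption: not repo code):

    import tensorflow as tf

    print(hasattr(tf.nn.relu, '__name__'))                             # True, a function
    print(hasattr(tf.keras.layers.LeakyReLU(alpha=0.2), '__name__'))   # False, a layer instance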

tests/test_layers.py (7 additions, 1 deletion)

@@ -1373,13 +1373,19 @@ def test_Softmax(advanced_activation_runner):
     advanced_activation_runner(layer, data)
 
 
+@pytest.mark.skipif(is_tensorflow_older_than('1.14.0') and is_tf_keras, reason='old tf version')
 def test_tf_nn_activation(runner):
-    for activation in [tf.nn.relu, 'relu', tf.nn.relu6, tf.nn.softmax]:
+    for activation in ['relu', tf.nn.relu, tf.nn.relu6, tf.nn.softmax, tf.nn.leaky_relu]:
         model = keras.Sequential([
             Dense(64, activation=activation, input_shape=[10]),
             Dense(64, activation=activation),
             Dense(1)
         ])
+        if is_tf_keras:
+            model.add(Activation(tf.keras.layers.LeakyReLU(alpha=0.2)))
+            model.add(Activation(tf.keras.layers.ReLU()))
+            model.add(tf.keras.layers.PReLU())
+            model.add(tf.keras.layers.LeakyReLU(alpha=0.5))
         x = np.random.rand(5, 10).astype(np.float32)
         expected = model.predict(x)
         onnx_model = keras2onnx.convert_keras(model, model.name)
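The comparison against expected is handled by the test suite's runner fixture; as a hedged sketch only (onnxruntime and the provider name are assumptions, not part of this commit), the check could look like:

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession(onnx_model.SerializeToString(), providers=['CPUExecutionProvider'])
    feed = {sess.get_inputs()[0].name: x}
    actual = sess.run(None, feed)[0]
    np.testing.assert_allclose(expected, actual, rtol=1.e-3, atol=1.e-6)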
