Skip to content
This repository was archived by the owner on Oct 13, 2021. It is now read-only.

Commit 74e855a

Browse files
authored
Convert tf EinSum, OneHot, LogicalAnd/Not etc (#449)
1 parent 9316ad6 commit 74e855a

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

keras2onnx/_builtin.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class TYPES:
3434
Conv2D = 'Conv2D'
3535
Cumsum = 'Cumsum'
3636
DepthwiseConv2dNative = 'DepthwiseConv2dNative'
37+
Einsum = 'Einsum'
3738
ExpandDims = 'ExpandDims'
3839
Fill = 'Fill'
3940
FloorDiv = 'FloorDiv'
@@ -44,6 +45,8 @@ class TYPES:
4445
GatherV2 = 'GatherV2'
4546
GreaterEqual = 'GreaterEqual'
4647
LessEqual = 'LessEqual'
48+
LogicalAnd = 'LogicalAnd'
49+
LogicalNot = 'LogicalNot'
4750
LogSoftmax = 'LogSoftmax'
4851
MatMul = 'MatMul'
4952
Max = 'Max'
@@ -54,6 +57,7 @@ class TYPES:
5457
NonMaxSuppressionV2 = 'NonMaxSuppressionV2'
5558
NonMaxSuppressionV3 = 'NonMaxSuppressionV3'
5659
NotEqual = 'NotEqual'
60+
OneHot = 'OneHot'
5761
Pack = 'Pack'
5862
Pad = 'Pad'
5963
PadV2 = 'PadV2'
@@ -742,6 +746,20 @@ def convert_tf_conv2d(scope, operator, container):
742746
_convert_tf_conv2d(scope, operator, container)
743747

744748

749+
@converter_func(TYPES.Einsum)
def convert_tf_einsum(scope, operator, container):
    """Convert a TensorFlow Einsum node to the ONNX Einsum operator.

    The TF 'equation' attribute (stored as bytes) is decoded and passed
    through verbatim; ONNX introduced Einsum in opset 12, so earlier
    targets are rejected up front.
    """
    if operator.target_opset < 12:
        raise ValueError("Einsum op is not supported until opset 12")
    builder = OnnxOperatorBuilder(container, scope)
    tf_node = operator.raw_operator
    # TF stores attributes as bytes; ONNX expects a plain string.
    equation = tf_node.get_attr('equation').decode("utf-8")
    builder.add_node_with_output(
        "Einsum",
        operator.input_full_names,
        operator.output_full_names,
        name=operator.full_name,
        equation=equation)
761+
762+
745763
@converter_func(TYPES.ExpandDims)
746764
def convert_tf_expand_dims(scope, operator, container):
747765
oopb = OnnxOperatorBuilder(container, scope)
@@ -940,6 +958,24 @@ def convert_tf_less_equal(scope, operator, container):
940958
_convert_tf_compare_equal(scope, operator, container, 'LessEqual', 'Greater')
941959

942960

961+
@converter_func(TYPES.LogicalAnd)
def convert_tf_logical_and(scope, operator, container):
    """Convert a TensorFlow LogicalAnd node to the ONNX 'And' operator.

    Fix: this function was previously named ``convert_tf_logical_not`` —
    the same name as the LogicalNot converter defined right after it, so
    the later ``def`` shadowed this one at module level. The converter
    registry still worked only because ``@converter_func`` captures the
    function object at definition time; the rename removes the shadowing
    and makes the name match what the function does.
    """
    oopb = OnnxOperatorBuilder(container, scope)
    # Element-wise boolean AND maps 1:1 onto the ONNX 'And' op.
    oopb.add_node_with_output('And',
                              operator.input_full_names,
                              operator.output_full_names,
                              name=operator.full_name)
968+
969+
970+
@converter_func(TYPES.LogicalNot)
def convert_tf_logical_not(scope, operator, container):
    """Convert a TensorFlow LogicalNot node to the ONNX 'Not' operator.

    Element-wise boolean negation translates directly; inputs and
    outputs are forwarded unchanged.
    """
    builder = OnnxOperatorBuilder(container, scope)
    builder.add_node_with_output(
        'Not',
        operator.input_full_names,
        operator.output_full_names,
        name=operator.full_name)
977+
978+
943979
@converter_func(TYPES.LogSoftmax)
944980
def convert_tf_logsoftmax(scope, operator, container):
945981
oopb = OnnxOperatorBuilder(container, scope)
@@ -1713,6 +1749,32 @@ def convert_tf_not_equal(scope, operator, container):
17131749
name=operator.full_name + '_not')
17141750

17151751

1752+
1753+
@converter_func(TYPES.OneHot)
def convert_tf_one_hot(scope, operator, container):
    """Convert a TensorFlow OneHot node to the ONNX OneHot operator.

    TF OneHot takes four inputs (indices, depth, on_value, off_value)
    plus an 'axis' attribute. ONNX OneHot (opset 9+) instead expects
    three inputs: indices, a 1-D depth, and a 1-D two-element 'values'
    tensor ordered [off_value, on_value]. The scalar TF inputs are
    therefore unsqueezed to rank 1 and off/on are concatenated in that
    order before emitting the node.
    """
    if operator.target_opset < 9:
        raise ValueError("OneHot op is not supported until opset 9")
    builder = OnnxOperatorBuilder(container, scope)
    axis = operator.raw_operator.get_attr('axis')
    prefix = operator.full_name

    # Promote each scalar input to a 1-D tensor as ONNX requires.
    depth = builder.apply_unsqueeze(operator.inputs[1].full_name,
                                    name=prefix + '_unsqueeze_1',
                                    axes=[0])
    on_value = builder.apply_unsqueeze(operator.inputs[2].full_name,
                                       name=prefix + '_unsqueeze_2',
                                       axes=[0])
    off_value = builder.apply_unsqueeze(operator.inputs[3].full_name,
                                        name=prefix + '_unsqueeze_3',
                                        axes=[0])
    # ONNX 'values' input must be [off_value, on_value] — order matters.
    values = builder.apply_concat(off_value + on_value,
                                  name=prefix + '_concat',
                                  axis=0)
    builder.add_node_with_output(
        'OneHot',
        [operator.inputs[0].full_name] + depth + values,
        operator.output_full_names,
        name=prefix + '_one_hot', axis=axis)
1776+
1777+
17161778
@converter_func(TYPES.ReadVariableOp)
17171779
def convert_tf_read_variable_op(scope, operator, container):
17181780
oopb = OnnxOperatorBuilder(container, scope)
@@ -2096,6 +2158,8 @@ def convert_tf_zeros_like(scope, operator, container):
20962158
"Erf": 9,
20972159
"Exp": ("apply_exp",),
20982160
"Floor": ("apply_floor",),
2161+
"Greater": ("apply_greater",),
2162+
"Less": ("apply_less",),
20992163
"Log": ("apply_log",),
21002164
"Mul": ("apply_mul",),
21012165
"Neg": ("apply_neg",),

tests/test_layers.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,21 @@ def my_func_4(x):
305305
assert runner('tf_maximum_minimum', onnx_model, [data1, data2], expected)
306306

307307

308+
@pytest.mark.skipif(get_maximum_opset_supported() < 9,
                    reason="opset < 9 is not supported.")
def test_tf_one_hot(runner):
    """Round-trip a tf.one_hot layer through keras2onnx and compare
    the ONNX runtime output against the Keras prediction."""
    def one_hot_fn(x):
        # depth=3, on_value=5.0, off_value=-1.0, axis=1
        return tf.one_hot(tf.cast(x, tf.int32), 3, 5.0, -1.0, 1)

    model = Sequential()
    model.add(Lambda(one_hot_fn, input_shape=[3]))

    onnx_model = keras2onnx.convert_keras(model, 'test_tf_one_hot')
    # NOTE(review): writes an artifact into the working directory —
    # presumably for manual inspection; confirm it is intentional.
    keras2onnx.save_model(onnx_model, 'one_hot.onnx')

    data = np.array([[0, 1, 2]]).astype(np.float32)
    expected = model.predict(data)
    assert runner('tf_one_hot', onnx_model, data, expected)
321+
322+
308323
def test_tf_pad(runner):
309324
def my_func_1(x):
310325
paddings = tf.constant([[0, 0], [1, 3], [2, 4]])

0 commit comments

Comments
 (0)