@@ -34,6 +34,7 @@ class TYPES:
34
34
Conv2D = 'Conv2D'
35
35
Cumsum = 'Cumsum'
36
36
DepthwiseConv2dNative = 'DepthwiseConv2dNative'
37
+ Einsum = 'Einsum'
37
38
ExpandDims = 'ExpandDims'
38
39
Fill = 'Fill'
39
40
FloorDiv = 'FloorDiv'
@@ -44,6 +45,8 @@ class TYPES:
44
45
GatherV2 = 'GatherV2'
45
46
GreaterEqual = 'GreaterEqual'
46
47
LessEqual = 'LessEqual'
48
+ LogicalAnd = 'LogicalAnd'
49
+ LogicalNot = 'LogicalNot'
47
50
LogSoftmax = 'LogSoftmax'
48
51
MatMul = 'MatMul'
49
52
Max = 'Max'
@@ -54,6 +57,7 @@ class TYPES:
54
57
NonMaxSuppressionV2 = 'NonMaxSuppressionV2'
55
58
NonMaxSuppressionV3 = 'NonMaxSuppressionV3'
56
59
NotEqual = 'NotEqual'
60
+ OneHot = 'OneHot'
57
61
Pack = 'Pack'
58
62
Pad = 'Pad'
59
63
PadV2 = 'PadV2'
@@ -742,6 +746,20 @@ def convert_tf_conv2d(scope, operator, container):
742
746
_convert_tf_conv2d (scope , operator , container )
743
747
744
748
749
+ @converter_func (TYPES .Einsum )
750
+ def convert_tf_einsum (scope , operator , container ):
751
+ if operator .target_opset < 12 :
752
+ raise ValueError ("Einsum op is not supported until opset 12" )
753
+ oopb = OnnxOperatorBuilder (container , scope )
754
+ node = operator .raw_operator
755
+ equation_str = node .get_attr ('equation' ).decode ("utf-8" )
756
+ oopb .add_node_with_output ("Einsum" ,
757
+ operator .input_full_names ,
758
+ operator .output_full_names ,
759
+ name = operator .full_name ,
760
+ equation = equation_str )
761
+
762
+
745
763
@converter_func (TYPES .ExpandDims )
746
764
def convert_tf_expand_dims (scope , operator , container ):
747
765
oopb = OnnxOperatorBuilder (container , scope )
@@ -940,6 +958,24 @@ def convert_tf_less_equal(scope, operator, container):
940
958
_convert_tf_compare_equal (scope , operator , container , 'LessEqual' , 'Greater' )
941
959
942
960
961
+ @converter_func (TYPES .LogicalAnd )
962
+ def convert_tf_logical_not (scope , operator , container ):
963
+ oopb = OnnxOperatorBuilder (container , scope )
964
+ oopb .add_node_with_output ('And' ,
965
+ operator .input_full_names ,
966
+ operator .output_full_names ,
967
+ name = operator .full_name )
968
+
969
+
970
+ @converter_func (TYPES .LogicalNot )
971
+ def convert_tf_logical_not (scope , operator , container ):
972
+ oopb = OnnxOperatorBuilder (container , scope )
973
+ oopb .add_node_with_output ('Not' ,
974
+ operator .input_full_names ,
975
+ operator .output_full_names ,
976
+ name = operator .full_name )
977
+
978
+
943
979
@converter_func (TYPES .LogSoftmax )
944
980
def convert_tf_logsoftmax (scope , operator , container ):
945
981
oopb = OnnxOperatorBuilder (container , scope )
@@ -1713,6 +1749,32 @@ def convert_tf_not_equal(scope, operator, container):
1713
1749
name = operator .full_name + '_not' )
1714
1750
1715
1751
1752
+
1753
+ @converter_func (TYPES .OneHot )
1754
+ def convert_tf_one_hot (scope , operator , container ):
1755
+ if operator .target_opset < 9 :
1756
+ raise ValueError ("OneHot op is not supported until opset 9" )
1757
+ oopb = OnnxOperatorBuilder (container , scope )
1758
+ node = operator .raw_operator
1759
+ axis = node .get_attr ('axis' )
1760
+
1761
+ depth = oopb .apply_unsqueeze (operator .inputs [1 ].full_name ,
1762
+ name = operator .full_name + '_unsqueeze_1' ,
1763
+ axes = [0 ])
1764
+ on_value = oopb .apply_unsqueeze (operator .inputs [2 ].full_name ,
1765
+ name = operator .full_name + '_unsqueeze_2' ,
1766
+ axes = [0 ])
1767
+ off_value = oopb .apply_unsqueeze (operator .inputs [3 ].full_name ,
1768
+ name = operator .full_name + '_unsqueeze_3' ,
1769
+ axes = [0 ])
1770
+ off_on_value = oopb .apply_concat (off_value + on_value ,
1771
+ name = operator .full_name + '_concat' ,
1772
+ axis = 0 )
1773
+ oopb .add_node_with_output ('OneHot' , [operator .inputs [0 ].full_name ] + depth + off_on_value ,
1774
+ operator .output_full_names ,
1775
+ name = operator .full_name + '_one_hot' , axis = axis )
1776
+
1777
+
1716
1778
@converter_func (TYPES .ReadVariableOp )
1717
1779
def convert_tf_read_variable_op (scope , operator , container ):
1718
1780
oopb = OnnxOperatorBuilder (container , scope )
@@ -2096,6 +2158,8 @@ def convert_tf_zeros_like(scope, operator, container):
2096
2158
"Erf" : 9 ,
2097
2159
"Exp" : ("apply_exp" ,),
2098
2160
"Floor" : ("apply_floor" ,),
2161
+ "Greater" : ("apply_greater" ,),
2162
+ "Less" : ("apply_less" ,),
2099
2163
"Log" : ("apply_log" ,),
2100
2164
"Mul" : ("apply_mul" ,),
2101
2165
"Neg" : ("apply_neg" ,),
0 commit comments