@@ -1613,8 +1613,8 @@ def test_Bidirectional_seqlen_none(self):
1613
1613
expected = model .predict (x )
1614
1614
self .assertTrue (run_onnx_runtime (onnx_model .graph .name , onnx_model , x , expected , self .model_files ))
1615
1615
1616
- @unittest .skipIf (is_tf2 , 'TODO' )
1617
1616
def test_rnn_state_passing (self ):
1617
+ K .clear_session ()
1618
1618
for rnn_class in [SimpleRNN , GRU , LSTM ]:
1619
1619
input1 = Input (shape = (None , 5 ))
1620
1620
input2 = Input (shape = (None , 5 ))
@@ -1771,38 +1771,43 @@ def test_recursive_and_shared_model(self):
1771
1771
expected = keras_model .predict (x )
1772
1772
self .assertTrue (run_onnx_runtime ('recursive_and_shared' , onnx_model , x , expected , self .model_files ))
1773
1773
1774
- @unittest .skipIf (is_keras_older_than ("2.2.4" ) or is_tf_keras or is_tf2 ,
1774
+ @unittest .skipIf (is_keras_older_than ("2.2.4" ),
1775
1775
"Low keras version is not supported." )
1776
1776
def test_shared_model_2 (self ):
1777
1777
K .set_learning_phase (0 )
1778
1778
1779
- def _conv_layer (input , filters , kernel_size , strides = 1 , dilation_rate = 1 ):
1779
+ def _conv_layer (input , filters , kernel_size , relu_flag = False , strides = 1 , dilation_rate = 1 ):
1780
1780
padding = 'same' if strides == 1 else 'valid'
1781
1781
if strides > 1 :
1782
1782
input = ZeroPadding2D (((0 , 1 ), (0 , 1 )), data_format = K .image_data_format ())(input )
1783
1783
x = Conv2D (filters = filters , kernel_size = kernel_size , strides = strides ,
1784
1784
padding = padding , use_bias = False , dilation_rate = dilation_rate )(input )
1785
1785
ch_axis = 1 if K .image_data_format () == 'channels_first' else - 1
1786
1786
x = BatchNormalization (axis = ch_axis )(x )
1787
- return ReLU ()(x )
1787
+ if relu_flag :
1788
+ return ReLU ()(x )
1789
+ else :
1790
+ return x
1788
1791
1789
- def _model ():
1792
+ def _model (relu_flag = False ):
1790
1793
input = Input (shape = (3 , 320 , 320 ), name = 'input_1' )
1791
- x = _conv_layer (input , 16 , 3 )
1794
+ x = _conv_layer (input , 16 , 3 , relu_flag )
1792
1795
return Model (inputs = input , outputs = x , name = 'backbone' )
1793
1796
1794
- input = Input (shape = (3 , 320 , 320 ), name = 'input' )
1795
- backbone = _model ()
1796
- x = backbone (input )
1797
- x = _conv_layer (x , 16 , 3 )
1798
- model = Model (inputs = [input ], outputs = [x ])
1797
+ relu_flags = [False ] if is_tf2 or is_tf_keras else [True , False ]
1798
+ for relu_flag_ in relu_flags :
1799
+ input = Input (shape = (3 , 320 , 320 ), name = 'input' )
1800
+ backbone = _model (relu_flag_ )
1801
+ x = backbone (input )
1802
+ x = _conv_layer (x , 16 , 3 )
1803
+ model = Model (inputs = [input ], outputs = [x ])
1799
1804
1800
- onnx_model = keras2onnx .convert_keras (model , model .name )
1801
- x = np .random .rand (2 , 3 , 320 , 320 ).astype (np .float32 )
1802
- expected = model .predict (x )
1803
- self .assertTrue (run_onnx_runtime (onnx_model .graph .name , onnx_model , x , expected , self .model_files ))
1805
+ onnx_model = keras2onnx .convert_keras (model , model .name )
1806
+ x = np .random .rand (2 , 3 , 320 , 320 ).astype (np .float32 )
1807
+ expected = model .predict (x )
1808
+ self .assertTrue (run_onnx_runtime (onnx_model .graph .name , onnx_model , x , expected , self .model_files ))
1804
1809
1805
- @unittest .skipIf (is_keras_older_than ("2.2.4" ) or is_tf_keras ,
1810
+ @unittest .skipIf (is_keras_older_than ("2.2.4" ),
1806
1811
"ReLU support requires keras 2.2.4 or later." )
1807
1812
def test_shared_model_3 (self ):
1808
1813
def _bottleneck (x , filters , activation , strides , block_id ):
@@ -1851,7 +1856,8 @@ def convnet_7(input_shape, activation):
1851
1856
x = _bottleneck (x , filters = 32 , strides = 2 , activation = activation , block_id = 2 )
1852
1857
return Model (inputs = input , outputs = x , name = 'convnet_7' )
1853
1858
1854
- for activation in ['relu' , 'leaky' ]:
1859
+ activation_list = ['leaky' ] if is_tf2 or is_tf_keras else ['relu' , 'leaky' ]
1860
+ for activation in activation_list :
1855
1861
model = convnet_7 (input_shape = (3 , 96 , 128 ), activation = activation )
1856
1862
onnx_model = keras2onnx .convert_keras (model , model .name )
1857
1863
x = np .random .rand (1 , 3 , 96 , 128 ).astype (np .float32 )
0 commit comments