@@ -552,17 +552,38 @@ def test_batch_normalization(self):
552
552
self ._batch_norm_helper (data , 'zeros' , 'zeros' , False , True , 1 )
553
553
554
554
def test_batch_normalization_2 (self ):
555
- # test batch normalization on 2D input
556
- input_dim = 10
557
- batch_size = 4
558
- model = keras .models .Sequential ()
559
- model .add (keras .layers .InputLayer (input_shape = (input_dim ,)))
560
- model .add (keras .layers .BatchNormalization (axis = - 1 ))
561
- model .add (keras .layers .Dense (5 ))
562
- data = np .random .randn (batch_size , input_dim ).astype (np .float32 )
563
- onnx_model = keras2onnx .convert_keras (model )
564
- expected = model .predict (data )
565
- self .assertTrue (self .run_onnx_runtime ('test_batch_normalization_2' , onnx_model , [data ], expected ))
555
+ for axis in [1 , - 1 ]:
556
+ batch_size = 4
557
+ input_dim_1 = 10
558
+ input_dim_2 = 20
559
+ input_dim_3 = 30
560
+
561
+ model = keras .models .Sequential ()
562
+ model .add (keras .layers .InputLayer (input_shape = (input_dim_1 ,)))
563
+ model .add (keras .layers .BatchNormalization (axis = axis ))
564
+ model .add (keras .layers .Dense (5 ))
565
+ data = np .random .randn (batch_size , input_dim_1 ).astype (np .float32 )
566
+ onnx_model = keras2onnx .convert_keras (model )
567
+ expected = model .predict (data )
568
+ self .assertTrue (self .run_onnx_runtime ('test_batch_normalization_2_2d' , onnx_model , [data ], expected ))
569
+
570
+ model = keras .models .Sequential ()
571
+ model .add (keras .layers .InputLayer (input_shape = (input_dim_1 , input_dim_2 )))
572
+ model .add (keras .layers .BatchNormalization (axis = axis ))
573
+ model .add (keras .layers .Dense (5 ))
574
+ data = np .random .randn (batch_size , input_dim_1 , input_dim_2 ).astype (np .float32 )
575
+ onnx_model = keras2onnx .convert_keras (model )
576
+ expected = model .predict (data )
577
+ self .assertTrue (self .run_onnx_runtime ('test_batch_normalization_2_3d' , onnx_model , [data ], expected ))
578
+
579
+ model = keras .models .Sequential ()
580
+ model .add (keras .layers .InputLayer (input_shape = (input_dim_1 , input_dim_2 , input_dim_3 )))
581
+ model .add (keras .layers .BatchNormalization (axis = axis ))
582
+ model .add (keras .layers .Dense (5 ))
583
+ data = np .random .randn (batch_size , input_dim_1 , input_dim_2 , input_dim_3 ).astype (np .float32 )
584
+ onnx_model = keras2onnx .convert_keras (model )
585
+ expected = model .predict (data )
586
+ self .assertTrue (self .run_onnx_runtime ('test_batch_normalization_2_4d' , onnx_model , [data ], expected ))
566
587
567
588
def test_simpleRNN (self ):
568
589
inputs1 = keras .Input (shape = (3 , 1 ))
0 commit comments