diff --git a/keras2onnx/ke2onnx/batch_norm.py b/keras2onnx/ke2onnx/batch_norm.py index fdcb9037..45728ee5 100644 --- a/keras2onnx/ke2onnx/batch_norm.py +++ b/keras2onnx/ke2onnx/batch_norm.py @@ -10,7 +10,7 @@ def convert_keras_batch_normalization(scope, operator, container): op = operator.raw_operator - if op.axis != 3 and op.axis != -1: + if (op.axis != 3 and op.axis != -1) or len(op.input_shape) == 2: adjusted_input_name = operator.inputs[0].full_name else: adjusted_input_name = scope.get_unique_variable_name(operator.inputs[0].full_name + '_transposed') @@ -49,7 +49,7 @@ def convert_keras_batch_normalization(scope, operator, container): momentum = op.momentum spatial = 1 - if op.axis != 3 and op.axis != -1: + if (op.axis != 3 and op.axis != -1) or len(op.input_shape) == 2: # If no transpose is required, we can simply use the output of ONNX BatchNorm as the final outcome apply_batch_norm(scope, input_tensor_names, operator.output_full_names, container, operator_name=operator.full_name, epsilon=epsilon, is_test=is_test, diff --git a/tests/test_layers.py b/tests/test_layers.py index f0aa0eb1..e2fce174 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -551,6 +551,19 @@ def test_batch_normalization(self): self._batch_norm_helper(data, 'ones', 'ones', True, False, 1) self._batch_norm_helper(data, 'zeros', 'zeros', False, True, 1) + def test_batch_normalization_2(self): + # test batch normalization on 2D input + input_dim = 10 + batch_size = 4 + model = keras.models.Sequential() + model.add(keras.layers.InputLayer(input_shape=(input_dim,))) + model.add(keras.layers.BatchNormalization(axis=-1)) + model.add(keras.layers.Dense(5)) + data = np.random.randn(batch_size, input_dim).astype(np.float32) + onnx_model = keras2onnx.convert_keras(model) + expected = model.predict(data) + self.assertTrue(self.run_onnx_runtime('test_batch_normalization_2', onnx_model, [data], expected)) + def test_simpleRNN(self): inputs1 = keras.Input(shape=(3, 1)) cls = keras.layers.SimpleRNN(2, return_state=False, return_sequences=True)