This repository was archived by the owner on Oct 13, 2021. It is now read-only.

Commit 45542a9

Handle multiple dimension case for BatchNormalization (#106)
1 parent: 3bf952a · commit: 45542a9

2 files changed: +44 / -16 lines

keras2onnx/ke2onnx/batch_norm.py

Lines changed: 12 additions & 5 deletions
@@ -10,11 +10,18 @@
 
 def convert_keras_batch_normalization(scope, operator, container):
     op = operator.raw_operator
-    if (op.axis != 3 and op.axis != -1) or len(op.input_shape) == 2:
+    skip_transpose = (op.axis != len(op.input_shape) - 1 and op.axis != -1) or len(op.input_shape) <= 2
+    if not skip_transpose:
+        perm_1 = list(range(1, len(op.input_shape) - 1))
+        perm_1 = [0, len(op.input_shape) - 1] + perm_1
+        perm_2 = list(range(2, len(op.input_shape)))
+        perm_2 = [0] + perm_2 + [1]
+
+    if skip_transpose:
         adjusted_input_name = operator.inputs[0].full_name
     else:
         adjusted_input_name = scope.get_unique_variable_name(operator.inputs[0].full_name + '_transposed')
-        apply_transpose(scope, operator.inputs[0].full_name, adjusted_input_name, container, perm=[0, 3, 1, 2])
+        apply_transpose(scope, operator.inputs[0].full_name, adjusted_input_name, container, perm=perm_1)
 
     input_tensor_names = [adjusted_input_name]
 
@@ -49,7 +56,7 @@ def convert_keras_batch_normalization(scope, operator, container):
     momentum = op.momentum
     spatial = 1
 
-    if (op.axis != 3 and op.axis != -1) or len(op.input_shape) == 2:
+    if skip_transpose:
         # If no transpose is required, we can simply use the output of ONNX BatchNorm as the final outcome
         apply_batch_norm(scope, input_tensor_names, operator.output_full_names, container,
                          operator_name=operator.full_name, epsilon=epsilon, is_test=is_test,
@@ -61,5 +68,5 @@ def convert_keras_batch_normalization(scope, operator, container):
                          operator_name=operator.full_name, epsilon=epsilon, is_test=is_test,
                          momentum=momentum, spatial=spatial)
 
-        # Permute [N,C,H,W] to [N,H,W,C]
-        apply_transpose(scope, intermediate_output_name, operator.outputs[0].full_name, container, perm=[0, 2, 3, 1])
+        # For the 4D case, this permutes [N,C,H,W] back to [N,H,W,C]
+        apply_transpose(scope, intermediate_output_name, operator.outputs[0].full_name, container, perm=perm_2)
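
As a quick illustration (not part of the commit), the generalized perm_1/perm_2 computed above reduce to the previously hard-coded NHWC-to-NCHW and NCHW-to-NHWC permutations when the input is 4D. A minimal standalone sketch, assuming a channels-last input; the helper name channel_permutations is purely for illustration:

# Sketch of the permutation logic introduced in this commit, for various input ranks.
def channel_permutations(rank):
    # perm_1 moves the trailing channel axis to position 1 (channels-last -> channels-first)
    perm_1 = [0, rank - 1] + list(range(1, rank - 1))
    # perm_2 moves the channel axis back to the end (channels-first -> channels-last)
    perm_2 = [0] + list(range(2, rank)) + [1]
    return perm_1, perm_2

print(channel_permutations(3))  # ([0, 2, 1], [0, 2, 1])
print(channel_permutations(4))  # ([0, 3, 1, 2], [0, 2, 3, 1])  -- the old hard-coded 4D values
print(channel_permutations(5))  # ([0, 4, 1, 2, 3], [0, 2, 3, 4, 1])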

tests/test_layers.py

Lines changed: 32 additions & 11 deletions
@@ -552,17 +552,38 @@ def test_batch_normalization(self):
         self._batch_norm_helper(data, 'zeros', 'zeros', False, True, 1)
 
     def test_batch_normalization_2(self):
-        # test batch normalization on 2D input
-        input_dim = 10
-        batch_size = 4
-        model = keras.models.Sequential()
-        model.add(keras.layers.InputLayer(input_shape=(input_dim,)))
-        model.add(keras.layers.BatchNormalization(axis=-1))
-        model.add(keras.layers.Dense(5))
-        data = np.random.randn(batch_size, input_dim).astype(np.float32)
-        onnx_model = keras2onnx.convert_keras(model)
-        expected = model.predict(data)
-        self.assertTrue(self.run_onnx_runtime('test_batch_normalization_2', onnx_model, [data], expected))
+        for axis in [1, -1]:
+            batch_size = 4
+            input_dim_1 = 10
+            input_dim_2 = 20
+            input_dim_3 = 30
+
+            model = keras.models.Sequential()
+            model.add(keras.layers.InputLayer(input_shape=(input_dim_1,)))
+            model.add(keras.layers.BatchNormalization(axis=axis))
+            model.add(keras.layers.Dense(5))
+            data = np.random.randn(batch_size, input_dim_1).astype(np.float32)
+            onnx_model = keras2onnx.convert_keras(model)
+            expected = model.predict(data)
+            self.assertTrue(self.run_onnx_runtime('test_batch_normalization_2_2d', onnx_model, [data], expected))
+
+            model = keras.models.Sequential()
+            model.add(keras.layers.InputLayer(input_shape=(input_dim_1, input_dim_2)))
+            model.add(keras.layers.BatchNormalization(axis=axis))
+            model.add(keras.layers.Dense(5))
+            data = np.random.randn(batch_size, input_dim_1, input_dim_2).astype(np.float32)
+            onnx_model = keras2onnx.convert_keras(model)
+            expected = model.predict(data)
+            self.assertTrue(self.run_onnx_runtime('test_batch_normalization_2_3d', onnx_model, [data], expected))
+
+            model = keras.models.Sequential()
+            model.add(keras.layers.InputLayer(input_shape=(input_dim_1, input_dim_2, input_dim_3)))
+            model.add(keras.layers.BatchNormalization(axis=axis))
+            model.add(keras.layers.Dense(5))
+            data = np.random.randn(batch_size, input_dim_1, input_dim_2, input_dim_3).astype(np.float32)
+            onnx_model = keras2onnx.convert_keras(model)
+            expected = model.predict(data)
+            self.assertTrue(self.run_onnx_runtime('test_batch_normalization_2_4d', onnx_model, [data], expected))
 
     def test_simpleRNN(self):
         inputs1 = keras.Input(shape=(3, 1))