This repository was archived by the owner on Oct 13, 2021. It is now read-only.

Bad conversion of BatchNormalization for 2D input #92

Closed
@yanivbenny

Description


Hi

I have a Keras model with a BatchNormalization layer between Dense layers.
The BatchNormalization layer has an 'axis' parameter that defaults to -1.
It can be changed when batch normalization should be performed on an axis other than the last one.
For example, BatchNormalization after a Conv2D layer in 'channels_first' format requires axis=1.
Note that for 2D input, axis=-1 and axis=1 refer to the same axis, so the default of -1 can be kept.
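A negative axis is resolved by adding the input rank. For a rank-2 input (batch, features), axis=-1 therefore resolves to 1. The sketch below illustrates this with a hypothetical helper, `normalize_axis`, which is not part of Keras or keras2onnx.

```python
def normalize_axis(axis: int, ndim: int) -> int:
    """Resolve a possibly negative axis index against a tensor rank (hypothetical helper)."""
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} out of range for rank {ndim}")
    # Python's modulo maps -1 to ndim - 1, -2 to ndim - 2, and so on
    return axis % ndim

# 2D input (batch, features): -1 and 1 name the same axis
print(normalize_axis(-1, 2))  # 1
# 4D channels_last input (N, H, W, C): -1 is the channel axis 3, not 1
print(normalize_axis(-1, 4))  # 3
```

The last line shows why the two settings only coincide for 2D input. For 4D input, the default axis=-1 points at the last dimension, not at axis 1.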

However, running the converted model with onnxruntime fails when axis=-1 (the default) and the input is 2D.
The error I receive is:
RuntimeError: [ONNXRuntimeError] : 1 : GENERAL ERROR : [TypeInferenceError] Invalid attribute perm {0, 3, 1, 2}, input shape = {1, 10}

When axis=1 is specified explicitly, prediction works fine.

I believe this is a keras2onnx issue rather than an onnxruntime issue, but tell me if I'm wrong.
As I see it, the converter assumes the input is 4D and must be in channels_first format. For axis=1, it treats the layout as already correct. For axis=-1, it assumes the input has to be permuted to channels_first and inserts a transpose with perm {0, 3, 1, 2}, without checking whether the input is actually 4D or 2D.
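ONNX BatchNormalization expects its input as (N, C, D1, ...), so the channel axis must sit at position 1. The sketch below shows a possible rank-aware alternative to a hardcoded {0, 3, 1, 2}. It computes the permutation from the actual input rank and skips the transpose when the normalized axis is already at position 1.

The helper `perm_to_channels_first` is hypothetical and is not taken from the keras2onnx source. NumPy's `transpose` stands in for the ONNX Transpose op to show the same shape mismatch.

```python
import numpy as np

def perm_to_channels_first(axis, ndim):
    """Permutation that moves `axis` to position 1, where ONNX BatchNormalization
    expects the channel dimension. Returns None when no transpose is needed.
    Hypothetical helper, not keras2onnx code."""
    if ndim < 2:
        raise ValueError("BatchNormalization needs at least a (batch, channel) input")
    axis = axis % ndim  # resolve negative axis (assumes axis is in range)
    if axis == 0:
        raise ValueError("cannot normalize over the batch axis")
    if axis == 1:
        return None  # already channels_first: e.g. every 2D (batch, features) input
    # Keep batch first, put the channel axis second, keep the rest in order
    rest = [d for d in range(1, ndim) if d != axis]
    return (0, axis, *rest)

x2d = np.zeros((1, 10), dtype=np.float32)
print(perm_to_channels_first(-1, x2d.ndim))  # None: no transpose for 2D input
print(perm_to_channels_first(-1, 4))         # (0, 3, 1, 2): NHWC -> NCHW

# Applying the hardcoded 4D permutation to a 2D tensor fails, much like the
# TypeInferenceError reported by onnxruntime
try:
    np.transpose(x2d, (0, 3, 1, 2))
except ValueError as err:
    print("ValueError:", err)
```

A converter that does transpose would also need the inverse permutation after the normalization to restore the original layout. For a permutation `perm`, that inverse is `np.argsort(perm)`.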

Code example

import numpy as np
import keras
import onnxmltools
import onnxruntime

input_dim = 10
batch_size = 4

# Model with a 2D input of shape (batch, input_dim)
model = keras.models.Sequential()
model.add(keras.layers.InputLayer(input_shape=(input_dim,)))
model.add(keras.layers.BatchNormalization(axis=-1))  # fails; axis=1 works
model.add(keras.layers.Dense(5))

# ONNX Runtime expects float32 input; np.random.randn returns float64
ximg = np.random.randn(batch_size, input_dim).astype(np.float32)

onnx_model = onnxmltools.convert_keras(model, target_opset=9)
sess = onnxruntime.InferenceSession(onnx_model.SerializeToString())
feed = {inp.name: ximg for inp in sess.get_inputs()}
preds_onnx = sess.run(None, feed)[0]

System information

OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Windows 10
ONNX installed from (source or binary): pip install keras2onnx onnxruntime
ONNX Runtime version: 0.4.0
keras2onnx version: 1.4.0
Python version: 3.6.0

Thanks
