Bad conversion of BatchNormalization for 2D input #92
Description
Hi
I have a Keras model with a BatchNormalization layer between Dense layers.
The BatchNormalization layer has an 'axis' parameter that defaults to -1.
It can be changed when normalization should be performed on an axis other than the last one.
For example, BatchNormalization after a Conv2D layer with the 'channels_first' data format requires axis=1.
Note that for 2D input of shape (batch, features), axis=-1 and axis=1 refer to the same axis, so the default -1 can be kept.
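To make the axis equivalence concrete, here is a minimal NumPy sketch of BatchNormalization at inference time. It assumes the standard formula gamma * (x - mean) / sqrt(var + eps) + beta with Keras' default epsilon of 1e-3. `batchnorm_inference` is a hypothetical helper, not part of Keras or keras2onnx. Because a negative axis is resolved against the input rank, axis=-1 and axis=1 select the same feature axis for 2D input.

```python
import numpy as np

def batchnorm_inference(x, gamma, beta, mean, var, axis=-1, eps=1e-3):
    """Hypothetical Keras-style BatchNormalization (inference mode) along `axis`."""
    axis = axis % x.ndim  # resolve a negative axis against the input rank
    # Reshape the per-feature parameters so they broadcast along `axis` only
    shape = [1] * x.ndim
    shape[axis] = x.shape[axis]
    g, b, m, v = (p.reshape(shape) for p in (gamma, beta, mean, var))
    return g * (x - m) / np.sqrt(v + eps) + b

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 10)).astype(np.float32)  # (batch, features)
gamma, beta = rng.standard_normal(10), rng.standard_normal(10)
mean, var = rng.standard_normal(10), rng.random(10) + 0.5

y_last = batchnorm_inference(x, gamma, beta, mean, var, axis=-1)
y_one = batchnorm_inference(x, gamma, beta, mean, var, axis=1)
print(np.array_equal(y_last, y_one))  # True: both resolve to the feature axis
```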
However, running the converted model with onnxruntime fails when axis=-1 (the default) for 2D input.
The error I receive is:
RuntimeError: [ONNXRuntimeError] : 1 : GENERAL ERROR : [TypeInferenceError] Invalid attribute perm {0, 3, 1, 2}, input shape = {1, 10}
When axis=1 is specified explicitly, prediction succeeds.
I believe this is a keras2onnx issue and not an onnxruntime issue, but tell me if I'm wrong.
As I see it, the converter assumes the input is 4D. With axis=1 it considers the input already channels_first and applies no permutation. With axis=-1 it assumes the input must be transposed to channels_first, so it inserts a Transpose with perm {0, 3, 1, 2} without checking whether the input is 4D or 2D.
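The behavior described above suggests the perm should be derived from the input's actual rank. Below is a minimal plain-Python sketch of such a rank-aware check. It assumes that ONNX BatchNormalization normalizes over dimension 1 (the channel dimension in its N x C x ... input layout). `transpose_perm_for_bn` and `inverse_perm` are hypothetical names for illustration, not keras2onnx's actual converter code.

```python
def transpose_perm_for_bn(rank, axis):
    """Hypothetical helper: return the Transpose perm that moves `axis` to
    position 1, or None if the input is already in that layout."""
    if rank < 2:
        raise ValueError("BatchNormalization needs input of rank >= 2")
    axis = axis % rank  # resolve -1 against the real input rank
    if axis == 1:
        return None  # no transpose needed (includes 2D input with axis=-1)
    # Keep batch axis 0, move `axis` to position 1, keep the rest in order
    rest = [d for d in range(1, rank) if d != axis]
    return [0, axis] + rest

def inverse_perm(perm):
    """Perm that undoes `perm`, restoring the original layout after BN."""
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return inv

print(transpose_perm_for_bn(4, -1))  # [0, 3, 1, 2]  (NHWC -> NCHW)
print(transpose_perm_for_bn(2, -1))  # None          (2D input: no transpose)
print(inverse_perm([0, 3, 1, 2]))    # [0, 2, 3, 1]  (NCHW -> NHWC)
```

With this kind of check, a 2D input never receives a 4-element perm, which is exactly what the TypeInferenceError above complains about.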
Code example
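import numpy as np
import keras
import onnxmltools
import onnxruntime

input_dim = 10
batch_size = 4

# Model on 2D input (batch, features); axis=-1 is the Keras default
model = keras.models.Sequential()
model.add(keras.layers.InputLayer(input_shape=(input_dim,)))
model.add(keras.layers.BatchNormalization(axis=-1))
model.add(keras.layers.Dense(5))

# The ONNX input is float32, so cast NumPy's default float64
ximg = np.random.randn(batch_size, input_dim).astype(np.float32)

onnx_model = onnxmltools.convert_keras(model, target_opset=9)
# With axis=-1 onnxruntime raises the TypeInferenceError quoted above; axis=1 works
sess = onnxruntime.InferenceSession(onnx_model.SerializeToString())
feed = {inp.name: ximg for inp in sess.get_inputs()}
preds_onnx = sess.run(None, feed)[0]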
System information
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Windows 10
ONNX installed from (source or binary): pip install keras2onnx onnxruntime
ONNX Runtime version: 0.4.0
keras2onnx version: 1.4.0
Python version: 3.6.0
Thanks