-
Notifications
You must be signed in to change notification settings - Fork 107
Handle multiple dimension case for BatchNormalization #106
Conversation
keras2onnx/ke2onnx/batch_norm.py
Outdated
@@ -62,4 +69,4 @@ def convert_keras_batch_normalization(scope, operator, container): | |||
momentum=momentum, spatial=spatial) | |||
|
|||
# Permute [N,C,H,W] to [N,H,W,C] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Clarify that this comment is for the 4D input shape case
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, suggested a small descriptive change but it's not strictly necessary
perm_1 = list(range(1, len(op.input_shape) - 1)) | ||
perm_1 = [0, len(op.input_shape) - 1] + perm_1 | ||
perm_2 = list(range(2, len(op.input_shape))) | ||
perm_2 = [0] + perm_2 + [1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, might not ever need to calculate perm_1 and perm_2 if skip_transpose is False. Perhaps move these into the respective sections for converter efficiency?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks, will do
lock tf to 1.9 ... travis build having issues with 1.10
No description provided.