Retain the parameter names in ONNX exporter #17551
Conversation
Force-pushed from 43910a1 to 481aa30.
torch/onnx/utils.py (Outdated)

    @@ -44,7 +44,7 @@ def set_training(model, mode):

     def export(model, args, f, export_params=True, verbose=False, training=False,
                input_names=None, output_names=None, aten=False, export_raw_ir=False,
    -           operator_export_type=None, opset_version=None):
    +           operator_export_type=None, opset_version=None, reserve_param_name=False):
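A minimal usage sketch against the signature at this revision; the flag name follows the diff above (it is discussed for renaming below), and the model and file name are illustrative:

    import torch
    import torch.nn as nn

    # Illustrative model and output path; reserve_param_name is the flag
    # as spelled in the diff above (this PR revision, not the final API).
    model = nn.Linear(3, 2)
    dummy_input = torch.randn(1, 3)
    torch.onnx.export(model, dummy_input, "linear.onnx",
                      reserve_param_name=True)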
A minor point: could we consider another name for this parameter, such as retain_param_name or export_with_param_name?
Force-pushed from 481aa30 to d2465c7.
Awesome. Why not make it a default right away?
torch/onnx/utils.py (Outdated)

    param_names = list(state_dict.keys())
    for i in range(len(graph_inputs)):
        if i >= user_input_num:
            graph_inputs[i].setUniqueName(param_names[i - user_input_num])
What happens if there's a duplicated name?
https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/ir.cpp#L670, it will rename the old owner of this name. We should be good. :-)
I suggest we always enable it (instead of having a flag to control it).
torch/onnx/utils.py (Outdated)

    graph_inputs = list(graph.inputs())
    user_input_num = len(graph_inputs) - len(state_dict)
    param_names = list(state_dict.keys())
    for i in range(len(graph_inputs)):
nit:

    for i, inp in enumerate(graph_inputs):
        inp.setUniqueName(...)
    if retain_param_name:
        graph_inputs = list(graph.inputs())
        user_input_num = len(graph_inputs) - len(state_dict)
        param_names = list(state_dict.keys())
Wait, the order of the keys is not stable.
nvm, it's an OrderedDict.
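For reference, a quick check that a module's state_dict preserves registration order (the module here is illustrative):

    import collections
    import torch.nn as nn

    sd = nn.Linear(3, 2).state_dict()
    assert isinstance(sd, collections.OrderedDict)
    print(list(sd.keys()))  # ['weight', 'bias'], in registration order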
I will rebase this PR and open another PR to enable it as the default behavior. I don't want the massive expect-file updates to overwhelm the important changes. To be safe, let's at least keep this flag for now.
Force-pushed from d2465c7 to eba0a02.
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
If you plan this flag as a temporary thing, start the argument name with an underscore (i.e. _preserve_param_names).
Can you please revert all of the changed .expect files?
ship it
@sampepose - the .expect file changes are expected, as we change the names of the parameters.
    graph_inputs = list(graph.inputs())
    user_input_num = len(graph_inputs) - len(state_dict)
    param_names = list(state_dict.keys())
    for i, inp in enumerate(graph_inputs):
Or:

    for inp, name in zip(graph_inputs[user_input_num:], param_names):
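For context, a self-contained sketch of the whole renaming step using the zip form suggested above; the helper name _assign_param_names is illustrative, and setUniqueName is the JIT Value method referenced earlier in this thread:

    def _assign_param_names(graph, state_dict):
        # Graph inputs are ordered as [user inputs..., parameters...], so
        # the trailing len(state_dict) inputs correspond to the parameters.
        graph_inputs = list(graph.inputs())
        user_input_num = len(graph_inputs) - len(state_dict)
        param_names = list(state_dict.keys())  # OrderedDict: stable order
        for inp, name in zip(graph_inputs[user_input_num:], param_names):
            inp.setUniqueName(name)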
So we will keep the names of the ONNX initializers the same as the names in the PyTorch state dict.
Later, we will make this the default behavior.
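A sketch of checking that outcome; it assumes the onnx package and the retain_param_name flag from this PR revision, and the model is illustrative:

    import onnx
    import torch
    import torch.nn as nn

    model = nn.Linear(3, 2)
    torch.onnx.export(model, torch.randn(1, 3), "linear.onnx",
                      retain_param_name=True)

    onnx_graph = onnx.load("linear.onnx").graph
    # With the flag enabled, initializer names match the state_dict keys.
    print([init.name for init in onnx_graph.initializer])  # e.g. ['weight', 'bias']
    print(list(model.state_dict().keys()))                 # ['weight', 'bias']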