Retain the parameter names in ONNX exporter #17551


Closed. houseroad wants to merge 6 commits.

Conversation

houseroad (Member):

So, we will keep the names of the ONNX initializers the same as the names in the PyTorch state dict.

Later, we will make this the default behavior.

@@ -44,7 +44,7 @@ def set_training(model, mode):

 def export(model, args, f, export_params=True, verbose=False, training=False,
            input_names=None, output_names=None, aten=False, export_raw_ir=False,
-           operator_export_type=None, opset_version=None):
+           operator_export_type=None, opset_version=None, reserve_param_name=False):
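
For context, a hedged sketch of how the new flag would be used (the module, file name, and onnx inspection step are illustrative and not from this PR; the flag name is taken from the diff above and is renamed later in the review):

import torch
import onnx

# Illustrative module; any nn.Module with parameters would do.
model = torch.nn.Linear(3, 2)
dummy_input = torch.randn(1, 3)

# Without the flag, initializers get numeric positional names (e.g. "1", "2");
# with it on, they keep the state_dict keys ("weight", "bias").
torch.onnx.export(model, dummy_input, "linear.onnx", reserve_param_name=True)

onnx_model = onnx.load("linear.onnx")
print([init.name for init in onnx_model.graph.initializer])
# Expected with the flag on: ['weight', 'bias']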


A minor point - could we consider another name for this parameter, such as retain_param_name or export_with_param_name?

houseroad changed the title from "Reserve the parameter names in ONNX exporter" to "Retain the parameter names in ONNX exporter" on Mar 3, 2019.
dzhulgakov (Collaborator) left a comment:


Awesome. Why not make it the default right away?

param_names = list(state_dict.keys())
for i in range(len(graph_inputs)):
    if i >= user_input_num:
        graph_inputs[i].setUniqueName(param_names[i - user_input_num])
Collaborator:

what happens if there's a duplicated name?

houseroad (Member, Author):

https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/ir.cpp#L670, it will rename the old owner of this name. We should be good. :-)

bddppq self-requested a review on March 13, 2019.
bddppq (Contributor) left a comment:

I suggest we always enable it (instead of having a flag to control it).

graph_inputs = list(graph.inputs())
user_input_num = len(graph_inputs) - len(state_dict)
param_names = list(state_dict.keys())
for i in range(len(graph_inputs)):
bddppq (Contributor) commented on Mar 13, 2019:

nit:

for i, inp in enumerate(graph_inputs):
    inp.setUniqueName(...)

if retain_param_name:
    graph_inputs = list(graph.inputs())
    user_input_num = len(graph_inputs) - len(state_dict)
    param_names = list(state_dict.keys())
Contributor:

wait, the order of the keys is not stable

Contributor:

nvm it's an OrderedDict
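
For the record, a quick check of that (state_dict() returns an OrderedDict whose keys follow parameter-registration order; the module here is illustrative):

import torch
from collections import OrderedDict

model = torch.nn.Linear(3, 2)
sd = model.state_dict()
assert isinstance(sd, OrderedDict)
print(list(sd.keys()))  # ['weight', 'bias'], in registration order every time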

houseroad (Member, Author):

I will rebase this PR and open another PR to enable it as the default behavior. I don't want the massive expect-file updates to overwhelm the important changes. To be safe, let's at least keep this flag for now.

facebook-github-bot (Contributor):

@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

dzhulgakov (Collaborator):

If you plan for this flag to be temporary, start the argument name with an underscore (i.e., _preserve_param_names).


sampepose (Contributor):

Can you please revert all of the changed .expect files except for TestOperators.test_retain_param_name.expect?



dzhulgakov (Collaborator) left a comment:

ship it

@sampepose - the .expect file changes are expected, since we are changing the names of the parameters.

graph_inputs = list(graph.inputs())
user_input_num = len(graph_inputs) - len(state_dict)
param_names = list(state_dict.keys())
for i, inp in enumerate(graph_inputs):
Collaborator:

or for inp, name in zip(graph_inputs[user_input_num:], param_names):
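
Spelled out, that zip variant would look something like this (a sketch, assuming the same graph, state_dict, and user_input_num as in the diff above):

# Rename only the trailing parameter inputs, pairing each with its
# state_dict key; the real user inputs at the front are left untouched.
graph_inputs = list(graph.inputs())
user_input_num = len(graph_inputs) - len(state_dict)
param_names = list(state_dict.keys())
for inp, name in zip(graph_inputs[user_input_num:], param_names):
    inp.setUniqueName(name)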
