-
Notifications
You must be signed in to change notification settings - Fork 399
[Feature] Add RewriteCheckPointHook
to rewrite key of state_dict
in checkpoint
#357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
||
|
||
@HOOKS.register_module() | ||
class MigrateCheckPointHook(Hook): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't forget to rename this file as well
removed and remapped keys. Removed keys has the next highest | ||
priority, once the original key has been removed, it cannot be mapped |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed and remapped keys. Removed keys has the next highest | |
priority, once the original key has been removed, it cannot be mapped | |
removed and remapped keys. Removed keys have the second highest | |
priority. Once the original key has been removed, it cannot be mapped |
Args: | ||
applied_key (str): Target state dictionary saved in checkpoints, which | ||
needs to be overwritten. Defaults to "state_dict". | ||
removed_prefix (List[str]): Key starts with corresponding prefix will |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed_prefix (List[str]): Key starts with corresponding prefix will | |
unused_prefix (str or List[str]): Keys starting with corresponding prefix(es) will |
removed_prefix
is ambiguous, as it sounds like the keys starting with these prefixes have already been removed.
two keys: ``src`` and ``dst``. ``src`` means the original key | ||
prefix and ``src`` means the target key prefix, see more | ||
information in examples. Defaults to []. | ||
merged_state_dicts (List[str]): A list of checkpoint paths need to be |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
merged_state_dicts (List[str]): A list of checkpoint paths need to be | |
merge_from (List[str]): A list of checkpoint paths needed to be |
Might be more concise
removed_prefix: List[str] = [], | ||
prefix_mapping: List[dict] = [], | ||
merged_state_dicts: List[str] = [], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
type hint: Union[List[str], str]
|
||
assert is_list_of( | ||
prefix_mapping, | ||
dict), ('prefix_mapping should be a list of dict a dict, but got ' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dict), ('prefix_mapping should be a list of dict a dict, but got ' | |
dict), ('prefix_mapping should be a list of dict or a dict, but got ' |
to other keys anymore. | ||
|
||
Args: | ||
applied_key (str): Target state dictionary saved in checkpoints, which |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
applied_key (str): Target state dictionary saved in checkpoints, which | |
state_key (str): Target state dictionary saved in checkpoints, which |
Not sure if it's a better name, but applied_key
sounds weird. What has been applied to the key?
Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.
Motivation
Add
MigrateCheckPointHook
to the rewriting key in the loaded checkpoint.The
state_dict
(or any other keys, such as ema_state_dict) incheckpoint
may not match the model strictly.MigrateCheckPointHook
can rewrite these keys to the matched ones.Modification
Please briefly describe what modification is made in this PR.
BC-breaking (Optional)
Does the modification introduce changes that break the backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.
Use cases (Optional)
Checklist