Fix key names for Transformer #2529
base: develop
Conversation
speechbrain/utils/checkpoints.py
```diff
@@ -94,10 +131,13 @@ def torch_recovery(obj, path, end_of_epoch):
     """
     del end_of_epoch  # Unused
     device = "cpu"

     state_dict = torch.load(path, map_location=device)
+    state_dict = map_old_state_dict_weights(state_dict, KEYS_MAPPING)
```
Shouldn't we have a more generic `state_dict` fixup function that would call into `map_old_state_dict_weights`, even if the general function only calls that one for now?
I don't understand your question/remark, sorry haha. Could you please explain a bit more what you meant by "general function"?
Having a function where we could add other `state_dict` fixes in the future.
so you mean something like:

```python
def foo(state_dict):
    map_old_state_dict_weights(state_dict)
```

so that someone can do:

```python
def foo(state_dict):
    bar(state_dict)
    map_old_state_dict_weights(state_dict)
```

easily?
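A minimal sketch of what such a generic fixup hook could look like, assuming `map_old_state_dict_weights` from this PR is in scope (the name `fix_old_state_dict` and the returned-dict convention are assumptions, not part of the PR):

```python
def fix_old_state_dict(state_dict, mapping):
    """Single entry point for all state_dict fixups.

    For now this only remaps renamed keys, but other fixes
    (dropping stale keys, dtype conversions, ...) could be
    chained here later without touching any call sites.
    """
    # map_old_state_dict_weights is the key-renaming helper
    # this PR adds to speechbrain/utils/checkpoints.py.
    state_dict = map_old_state_dict_weights(state_dict, mapping)
    return state_dict
```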
```python
mapping : dict
    A dictionary specifying the mapping between old and new keys.
```
The extent of what this does is unclear; in particular, to my understanding:

- This can affect any part of the key, which could look like `somemodule.somesubmodule.somefield` rather than just `somefield`.
- This will do partial matches on keys (e.g. `"a": "b"` will turn `_a` into `_b`).

The behavior should be documented to avoid accidentally including renames with too many side effects.
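To make the side effect concrete, here is a toy, self-contained version of the replace-based logic (illustrative only; the real helper lives in `speechbrain/utils/checkpoints.py`):

```python
def map_old_state_dict_weights(state_dict, mapping):
    # Rename keys by substring replacement, as in the PR's loop.
    for checkpoint_name, attribute_name in mapping.items():
        for full_checkpoint_name in list(state_dict.keys()):
            if checkpoint_name in full_checkpoint_name:
                full_attribute_name = full_checkpoint_name.replace(
                    checkpoint_name, attribute_name
                )
                state_dict[full_attribute_name] = state_dict.pop(
                    full_checkpoint_name
                )
    return state_dict


# A one-character mapping matches far more than intended:
sd = {"model._a.weight": 0, "model.head.bias": 1}
print(map_old_state_dict_weights(sd, {"a": "b"}))
# {'model._b.weight': 0, 'model.hebd.bibs': 1}
```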
```python
for checkpoint_name, attribute_name in mapping.items():
    for full_checkpoint_name in list(state_dict.keys()):
        if checkpoint_name in full_checkpoint_name:
            full_attribute_name = full_checkpoint_name.replace(
                checkpoint_name, attribute_name
            )
```
Should we have a warning here?
Yep. Either a warning or a log in debug mode.
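A quick sketch of the debug-log option (logger setup and message wording are assumptions, not code from this PR):

```python
import logging

logger = logging.getLogger(__name__)

# Inside the rename loop, after computing full_attribute_name:
logger.debug(
    "Mapping old state_dict key %r -> %r",
    full_checkpoint_name,
    full_attribute_name,
)
```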
What does this PR do?

This PR solves an issue introduced by #2489, which renamed a key in the Transformer decoder from `self.mutihead_attn` to `self.multihead_attn`. Doing so breaks the loading of older state dicts, since the previous key is no longer recognized. To solve this problem, I introduce a new function called `map_old_state_dict_weights`, which is applied directly within `torch_recovery`, `average_checkpoints`, and `_load_from_state_dict`. The first handles loading a checkpoint; the second handles averaging multiple checkpoints (the issue there is that, before averaging, every checkpoint must have the same keys); and the last handles loading the state_dict directly from the object, which can happen in our codebase (i.e., bypassing the checkpointer).
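As a usage sketch of the fix, reusing the toy `map_old_state_dict_weights` from above (the exact mapping entry and key layout are illustrative, not copied from the PR):

```python
import torch

# Old (typo'd) key fragment -> corrected one, as described above.
KEYS_MAPPING = {"mutihead_attn": "multihead_attn"}

# A state_dict saved before #2489 still uses the old name:
old_state_dict = {
    "decoder.layers.0.mutihead_attn.in_proj_weight": torch.zeros(3, 3),
}

new_state_dict = map_old_state_dict_weights(old_state_dict, KEYS_MAPPING)
print(list(new_state_dict.keys()))
# ['decoder.layers.0.multihead_attn.in_proj_weight']
```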