Fix keys name for Transformer #2529

Open
wants to merge 3 commits into base: develop

Adel-Moumen (Collaborator) commented Apr 26, 2024

What does this PR do?

This PR fixes an issue introduced by #2489. That PR renamed a key in the decoder transformer from self.mutihead_attn to self.multihead_attn. The rename breaks state dict loading, because checkpoints saved before the change still use the old key, which is no longer recognized. To solve this, I introduce a new function, map_old_state_dict_weights, which is applied within torch_recovery, average_checkpoints and _load_from_state_dict. The first covers loading a checkpoint. The second covers averaging multiple checkpoints (every checkpoint must have the same keys before the average is computed). The last covers loading the state_dict directly from the object, which can happen in our codebase (i.e. bypassing the checkpointer).
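As a rough illustration of the remapping idea described above, here is a minimal sketch using plain dictionaries. It is not the code from this PR: the helper name remap_state_dict_keys, the mapping contents, and the example key names are assumptions made for illustration only.

```python
def remap_state_dict_keys(state_dict, mapping):
    """Return a new dict in which old key fragments are replaced by new ones.

    Hypothetical helper: it uses substring replacement, so every occurrence
    of an old fragment anywhere in a key is rewritten.
    """
    remapped = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in mapping.items():
            if old in new_key:
                new_key = new_key.replace(old, new)
        remapped[new_key] = value
    return remapped


# Assumed mapping, derived from the rename described in the PR text.
HYPOTHETICAL_KEYS_MAPPING = {"mutihead_attn": "multihead_attn"}

# Made-up checkpoint keys; real checkpoints hold tensors, not lists.
old_checkpoint = {
    "decoder.layers.0.mutihead_attn.att.in_proj_weight": [0.1, 0.2],
    "decoder.layers.0.norm1.weight": [1.0],
}

new_checkpoint = remap_state_dict_keys(old_checkpoint, HYPOTHETICAL_KEYS_MAPPING)
print(sorted(new_checkpoint))
```

Because the corrected name does not contain the misspelled fragment, applying the sketch to an already-migrated checkpoint leaves it unchanged, which is what makes it safe to call on every load path.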

Before submitting
  • Did you read the contributor guideline?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Does your code adhere to project-specific code style and conventions?

PR review

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified
  • Confirm that the changes adhere to compatibility requirements (e.g., Python version, platform)
  • Review the self-review checklist to ensure the code is ready for review

Adel-Moumen marked this pull request as ready for review April 30, 2024 13:37
Adel-Moumen self-assigned this Apr 30, 2024
Adel-Moumen added the bug and important labels Apr 30, 2024
Adel-Moumen added this to the v1.0.1 milestone Apr 30, 2024
@@ -94,10 +131,13 @@ def torch_recovery(obj, path, end_of_epoch):
    """
    del end_of_epoch  # Unused
    device = "cpu"

    state_dict = torch.load(path, map_location=device)
    state_dict = map_old_state_dict_weights(state_dict, KEYS_MAPPING)
Collaborator

Shouldn't we have a more generic state_dict fixup function that would call into map_old_state_dict_weights, even if the general function only calls that one for now?

Collaborator Author

I don't understand your question/remark sorry haha. Could you please explain a bit more what you meant by general function?

Collaborator

Having a function where we could add other state_dict fixes in the future.

Collaborator Author

So you mean something like:

def foo(state_dict):
    return map_old_state_dict_weights(state_dict, KEYS_MAPPING)

so that someone can do:

def foo(state_dict):
    state_dict = bar(state_dict)
    return map_old_state_dict_weights(state_dict, KEYS_MAPPING)

easily?
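One way to picture the "generic fixup function" discussed in this thread is a single entry point that runs a list of fixes in order. This is only a sketch under assumptions: fixup_state_dict, STATE_DICT_FIXUPS and rename_typo_keys are hypothetical names, not functions from the codebase.

```python
from typing import Callable, Dict, List

StateDict = Dict[str, object]
Fixup = Callable[[StateDict], StateDict]


def rename_typo_keys(state_dict: StateDict) -> StateDict:
    """Hypothetical fix: rename the misspelled attention key fragment."""
    return {
        key.replace("mutihead_attn", "multihead_attn"): value
        for key, value in state_dict.items()
    }


# Future state_dict fixes would be appended to this list (assumed design).
STATE_DICT_FIXUPS: List[Fixup] = [rename_typo_keys]


def fixup_state_dict(state_dict: StateDict) -> StateDict:
    """Apply every registered fix in order and return the result."""
    for fixup in STATE_DICT_FIXUPS:
        state_dict = fixup(state_dict)
    return state_dict


fixed = fixup_state_dict({"layers.0.mutihead_attn.weight": 1, "layers.0.norm.bias": 2})
print(fixed)
```

With this shape, the loading paths would call only the general function, and adding another fix later would not require touching each call site.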

Comment on lines +97 to +98
mapping : dict
A dictionary specifying the mapping between old and new keys.
Collaborator

The extent of what this mapping does is unclear. In particular, to my understanding:

  • This can affect any part of the key, which could look like somemodule.somesubmodule.somefield rather than just somefield
  • This will do partial matches on keys (e.g. "a": "b" will turn _a into _b)

The behavior should be documented to avoid accidentally including renames with too many side effects.
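The side effects described above can be reproduced with plain strings. The sketch below contrasts substring replacement with a stricter variant that only renames whole dot-separated components. Both helper names and the example keys are assumptions for illustration, and the stricter variant is an alternative idea, not what the PR implements.

```python
def rename_substring(key, old, new):
    """Replace every occurrence of `old`, wherever it appears in the key."""
    return key.replace(old, new)


def rename_component(key, old, new):
    """Replace `old` only when it is an entire dot-separated component."""
    return ".".join(new if part == old else part for part in key.split("."))


# Substring matching rewrites more than the intended field.
print(rename_substring("layer.a.bias", "a", "b"))  # lbyer.b.bibs
print(rename_substring("_a", "a", "b"))            # _b

# Component matching leaves partial matches untouched.
print(rename_component("layer.a.bias", "a", "b"))  # layer.b.bias
print(rename_component("_a", "a", "b"))            # _a
```

A long, distinctive fragment such as the misspelled attention key is unlikely to collide with other names, but short mapping entries would be risky under substring matching.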

for checkpoint_name, attribute_name in mapping.items():
    for full_checkpoint_name in list(state_dict.keys()):
        if checkpoint_name in full_checkpoint_name:
            full_attribute_name = full_checkpoint_name.replace(
Collaborator

Should we have a warning here?

Collaborator Author

Yep. Either a warning or a log in debug mode.
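A possible shape for the warning suggested here, using the standard logging module. This is a sketch, not the PR's code: remap_with_logging and the example keys are hypothetical, and whether to log at warning or debug level is left open, as in the discussion.

```python
import logging

logger = logging.getLogger(__name__)


def remap_with_logging(state_dict, mapping):
    """Rename old key fragments and log each key that actually changes."""
    remapped = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in mapping.items():
            new_key = new_key.replace(old, new)
        if new_key != key:
            # Swap for logger.debug(...) if a quieter message is preferred.
            logger.warning("Renaming state_dict key %r to %r", key, new_key)
        remapped[new_key] = value
    return remapped
```

Logging only when a key changes keeps the message silent for checkpoints that already use the new names.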
