Add function to load model from pretrained checkpoint #1475

Open
2 tasks
guarin opened this issue Jan 12, 2024 · 0 comments


We should add a function to load backbones from the benchmark checkpoints. The function should roughly do the following:

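from torch.hub import load_state_dict_from_url
from torchvision.models import resnet50

model = resnet50()
# The file is a training checkpoint; the model weights sit under the "state_dict" key.
# map_location="cpu" lets the checkpoint load on machines without a GPU.
state_dict = load_state_dict_from_url("https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_simclr_2023-06-22_09-11-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt", map_location="cpu")
prefix = "backbone."
new_state_dict = {}
for key, value in state_dict["state_dict"].items():
    if key.startswith(prefix):
        # Slice the prefix off. str.lstrip() removes a set of characters, not a
        # prefix, so it would turn "backbone.conv1.weight" into "v1.weight".
        new_state_dict[key[len(prefix):]] = value
# strict=False is required because the checkpoint has no classification head (fc).
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
# missing_keys is a list, so compare it as a set.
assert set(missing_keys) == {"fc.weight", "fc.bias"}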

Maybe we can leave the `load_state_dict_from_url` call outside the function and have the function just take a state dict as input and return the new state dict as output.
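A minimal sketch of what such a function could look like, assuming the backbone weights are stored under a `backbone.` key prefix as in the snippet above. The name `extract_backbone_state_dict` and the `prefix` parameter are hypothetical, not an existing API. The example uses plain integers in place of tensors so it runs without downloading a checkpoint.

```python
from typing import Any, Dict, Mapping


def extract_backbone_state_dict(
    state_dict: Mapping[str, Any], prefix: str = "backbone."
) -> Dict[str, Any]:
    """Return the entries whose keys start with `prefix`, with the prefix removed.

    Hypothetical helper: the name and signature are illustrative only.
    """
    return {
        # Slice instead of str.lstrip(), which strips characters, not a prefix.
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Toy state dict with integers standing in for tensors.
example = {
    "backbone.conv1.weight": 1,
    "backbone.layer1.0.bn1.bias": 2,
    "projection_head.layers.0.weight": 3,
}
converted = extract_backbone_state_dict(example)
```

With a real checkpoint, the caller would presumably pass in the `"state_dict"` entry of the loaded checkpoint and then call `model.load_state_dict(converted, strict=False)`, since the classification head is not part of the backbone.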

TODO

  • Add function
  • Document function