Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load Segmentation Trainer Weights #1379

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

nilsleh
Copy link
Collaborator

@nilsleh nilsleh commented May 29, 2023

This PR changes the loading of a checkpoint in the segmentation model. Given that I have trained a model with a torchgeo trainer, I might want to do with the weights argument:

  • load the entire model (encoder and decoder)
  • load just a backbone
  • use only the model backbone from a trained model in a different task or at least make it available for other tasks

Not sure which of all should be supported by default, or whether there should be suggestions on how to do each of these (or other things).

@github-actions github-actions bot added the trainers PyTorch Lightning trainers label May 29, 2023
@nilsleh nilsleh requested a review from isaaccorley May 29, 2023 15:53
@nilsleh nilsleh added this to In progress in SSL4EO-L via automation May 29, 2023
@adamjstewart
Copy link
Collaborator

Need to think about this one. I wonder if there isn't an easier way to tell Lightning to save the encoder in a separate file. Then it's trainer-dependent instead of having to if-statement every possible naming scheme in one function.

@adamjstewart adamjstewart added this to the 0.4.2 milestone May 29, 2023
@isaaccorley
Copy link
Collaborator

I'm not sure I understand. You are trying to load a model trained using which trainer into the segmentation trainer? Can you give the error message?

@isaaccorley
Copy link
Collaborator

If you need to load a checkpoint in the same trainer you can just use pytorch Lightning's loading method.

SemanticSegmentationTask.load_from_checkpoint(ckpt_path)

@adamjstewart
Copy link
Collaborator

I believe the checkpoint comes from MoCo and/or SimCLR and looks very different than the checkpoint expected by SemanticSegmentationTask.

@isaaccorley
Copy link
Collaborator

I don't think so. MoCo and SimCLR tasks named the backbone "backbone" not "encoder.

@isaaccorley
Copy link
Collaborator

isaaccorley commented May 29, 2023

Ah I see what's happening. Right now we don't support loading an entire pretrained segmentation checkpoint using the weights option. We only support loading the encoder part of the weights, not the decoder. So the solution would be to just load the checkpoint directly using pytorch lighting's checkpoint option.

Edit: so in this case, the PR changes are the correct solution.

@adamjstewart
Copy link
Collaborator

Is this still needed or is this superseded by #1403?

@isaaccorley
Copy link
Collaborator

I believe this is separate. This has to do with loading an entire UNet checkpoint using the weights argument. Right now the weights argument only loads the weights into the encoder (not the decoder).

@adamjstewart adamjstewart removed this from the 0.4.2 milestone Sep 28, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
trainers PyTorch Lightning trainers
Projects
No open projects
SSL4EO-L
In progress
Development

Successfully merging this pull request may close these issues.

None yet

3 participants