Adapt segmentation trainer to work with ViT #1403

Open
wants to merge 3 commits into
base: main

nilsleh
Collaborator

@nilsleh nilsleh commented Jun 6, 2023

No description provided.

@nilsleh nilsleh requested a review from isaaccorley June 6, 2023 13:20
@github-actions github-actions bot added the "trainers" (PyTorch Lightning trainers) label Jun 6, 2023
@adamjstewart adamjstewart added this to In progress in SSL4EO-L via automation Jun 9, 2023
@adamjstewart
Collaborator

I assume this will also be needed in PixelwiseRegressionTask?

@adamjstewart adamjstewart added this to the 0.4.2 milestone Jun 9, 2023
@github-actions github-actions bot added the "scripts" (Training and evaluation scripts) label Jun 9, 2023
@nilsleh
Collaborator Author

nilsleh commented Jun 9, 2023

This is how I have hacked it so far.

@@ -93,7 +104,7 @@ def config_task(self) -> None:
                 _, state_dict = utils.extract_backbone(weights)
             else:
                 state_dict = get_weight(weights).get_state_dict(progress=True)
-            self.model.encoder.load_state_dict(state_dict)
+            self.model.encoder.load_state_dict(state_dict, strict=False)
Collaborator Author

@nilsleh nilsleh Jun 9, 2023

This is done because, for the tu-vit model, self.model.encoder has head.weight and head.bias entries, which the pretrained weights do not contain.

Collaborator Author

Additionally, the state dict keys of self.model.encoder are named model.{module}, whereas the ResNet backbones for Unet use just {module}.
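
A minimal sketch of how one might check what `strict=False` would silently skip, so that only the extra head entries are tolerated and a key-naming mismatch is still caught. It works on plain lists of key names rather than real tensors. The helper names `diff_state_dict_keys` and `only_head_missing` and all key names are hypothetical, not part of this PR. In PyTorch itself, `load_state_dict(..., strict=False)` returns an object with `missing_keys` and `unexpected_keys`, which could be inspected in the same way.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists.

    missing:    keys the model expects but the checkpoint lacks.
    unexpected: keys the checkpoint has but the model does not.
    """
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)
    unexpected = sorted(ckpt_set - model_set)
    return missing, unexpected


def only_head_missing(model_keys, checkpoint_keys, allowed_prefixes=("head.",)):
    """True if the only mismatches are missing keys under the allowed prefixes."""
    missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
    # str.startswith accepts a tuple of prefixes.
    return not unexpected and all(k.startswith(allowed_prefixes) for k in missing)


# Hypothetical key names, chosen only to mirror the situation described above.
encoder_keys = [
    "patch_embed.proj.weight",
    "blocks.0.attn.qkv.weight",
    "head.weight",
    "head.bias",
]
checkpoint_keys = ["patch_embed.proj.weight", "blocks.0.attn.qkv.weight"]

missing, unexpected = diff_state_dict_keys(encoder_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
print("safe to load non-strictly:", only_head_missing(encoder_keys, checkpoint_keys))
```

If the checkpoint keys and the encoder keys differ by a prefix, every key shows up as both missing and unexpected, so a check like this fails loudly instead of leaving the encoder randomly initialized.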

@@ -37,6 +37,9 @@ def extract_backbone(path: str) -> tuple[str, "OrderedDict[str, Tensor]"]:
         state_dict = OrderedDict(
             {k.replace("model.", ""): v for k, v in state_dict.items()}
         )
+    elif "vits16" in checkpoint["hyper_parameters"]["backbone"]:
Collaborator Author

The state dict keys of self.model.encoder are named model.{module}. This naming is created in the conversion script.
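
A minimal sketch of reconciling the two naming schemes by stripping a leading prefix from each key. The helper name `strip_prefix`, the key names, and the placeholder values are hypothetical. Unlike `k.replace("model.", "")`, which removes every occurrence of the substring, this only removes the prefix when it appears at the start of a key.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="model."):
    """Return a copy of state_dict with a leading prefix removed from each key.

    Keys that do not start with the prefix are kept unchanged.
    """
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


# Placeholder integers stand in for tensors; the key names are made up.
converted = OrderedDict(
    [("model.patch_embed.proj.weight", 0), ("model.blocks.0.norm1.weight", 1)]
)
stripped = strip_prefix(converted)
print(list(stripped))
```

The same helper could be applied in the opposite direction (adding the prefix) if it is the encoder, rather than the checkpoint, that uses the prefixed names.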

Collaborator

This seems very hyper-specific but I don't know enough to offer an alternative

Collaborator

Needs a shebang, a copyright header, and the executable bit set.

@@ -36,19 +36,30 @@ def config_task(self) -> None:
         """Configures the task based on kwargs parameters passed to the constructor."""
         weights = self.hyperparams["weights"]
 
+        if self.hyperparams["backbone"].startswith("tu-vit"):
+            encoder_depth = 4
+            decoder_channels = (256, 128, 64, 32)
Collaborator

Suggested change
-            decoder_channels = (256, 128, 64, 32)
+            decoder_channels: tuple[int, ...] = (256, 128, 64, 32)

This should placate mypy
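
A minimal sketch of why the annotation helps, assuming the other branch assigns a tuple of a different length (that branch is not shown in the diff, so the five-stage fallback values and the function name `unet_decoder_config` are assumptions). Without the annotation, mypy infers a fixed-length tuple type from the first assignment and rejects a later assignment of a different length.

```python
# Built-in generics such as tuple[int, ...] need Python 3.9+ at runtime.
def unet_decoder_config(backbone: str) -> tuple[int, tuple[int, ...]]:
    """Pick encoder depth and decoder channels for a backbone name."""
    # A variable-length tuple type lets both branches type-check.
    decoder_channels: tuple[int, ...]
    if backbone.startswith("tu-vit"):
        # Values taken from the diff under review.
        encoder_depth = 4
        decoder_channels = (256, 128, 64, 32)
    else:
        # Assumed fallback with five stages (not shown in the diff).
        encoder_depth = 5
        decoder_channels = (256, 128, 64, 32, 16)
    return encoder_depth, decoder_channels


# The backbone names below are illustrative only.
print(unet_decoder_config("tu-vit_small_patch16_224"))
print(unet_decoder_config("resnet50"))
```

In either branch the number of decoder channels equals the encoder depth, which is the invariant the two values have to satisfy together.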

@adamjstewart
Collaborator

Don't know why tests are failing but we need to fix that

@adamjstewart adamjstewart changed the title from "Adapt segmentation trainer to work with VIT" to "Adapt segmentation trainer to work with ViT" Jun 9, 2023
@adamjstewart adamjstewart removed this from the 0.4.2 milestone Sep 28, 2023