Adapt segmentation trainer to work with ViT #1403
Conversation
I assume this will also be needed in PixelwiseRegressionTask?
This is how I have hacked it so far.
@@ -93,7 +104,7 @@ def config_task(self) -> None:
         _, state_dict = utils.extract_backbone(weights)
     else:
         state_dict = get_weight(weights).get_state_dict(progress=True)
-    self.model.encoder.load_state_dict(state_dict)
+    self.model.encoder.load_state_dict(state_dict, strict=False)
This is done because for the tu-vit model, self.model.encoder has head.weight and bias.weight keys, which a pretrained model does not have, so a strict load would fail on the extra keys.
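The bookkeeping that `strict=False` relaxes can be sketched without PyTorch. This is only an illustration of the key-matching logic (the key names below are hypothetical ViT parameter names, not taken from the PR):

```python
def match_state_dicts(model_keys, ckpt_keys):
    """Mimic load_state_dict's bookkeeping: keys the model expects but the
    checkpoint lacks are "missing"; extra checkpoint keys are "unexpected".
    With strict=True either list being non-empty raises; strict=False just
    skips them."""
    missing = sorted(set(model_keys) - set(ckpt_keys))
    unexpected = sorted(set(ckpt_keys) - set(model_keys))
    return missing, unexpected

# A tu-vit encoder carries an extra classification head;
# the pretrained checkpoint does not (illustrative key names).
model_keys = [
    "patch_embed.proj.weight",
    "blocks.0.attn.qkv.weight",
    "head.weight",
    "head.bias",
]
ckpt_keys = ["patch_embed.proj.weight", "blocks.0.attn.qkv.weight"]

missing, unexpected = match_state_dicts(model_keys, ckpt_keys)
print(missing)     # the head keys absent from the checkpoint
print(unexpected)  # nothing extra in the checkpoint
```

With `strict=True` the mismatch above would raise a `RuntimeError`; `strict=False` loads the overlapping keys and leaves the head randomly initialized.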
Additionally, the state dict of self.model.encoder uses keys named model.{module}, whereas ResNet backbones for Unet use plain {module} keys.
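The prefix mismatch can be fixed by renaming checkpoint keys before loading. A minimal sketch, assuming the `model.` prefix described above (`strip_prefix` is a hypothetical helper, not part of torchgeo):

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="model."):
    """Drop a leading prefix from every key so checkpoint names line up
    with the encoder's own parameter names."""
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


ckpt = OrderedDict(
    [
        ("model.patch_embed.proj.weight", 0),
        ("model.blocks.0.norm1.weight", 1),
    ]
)
print(list(strip_prefix(ckpt)))
```

Note that the diff above uses `k.replace("model.", "")`, which rewrites every occurrence of the substring; matching only a leading prefix with `startswith` avoids mangling keys that happen to contain `model.` elsewhere.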
@@ -37,6 +37,9 @@ def extract_backbone(path: str) -> tuple[str, "OrderedDict[str, Tensor]"]:
     state_dict = OrderedDict(
         {k.replace("model.", ""): v for k, v in state_dict.items()}
     )
+elif "vits16" in checkpoint["hyper_parameters"]["backbone"]:
The state dict of self.model.encoder uses keys named model.{module}. This naming is created in the conversion script.
This seems very hyper-specific but I don't know enough to offer an alternative
Needs shebang, copyright, executable
@@ -36,19 +36,30 @@ def config_task(self) -> None:
     """Configures the task based on kwargs parameters passed to the constructor."""
     weights = self.hyperparams["weights"]

+    if self.hyperparams["backbone"].startswith("tu-vit"):
+        encoder_depth = 4
+        decoder_channels = (256, 128, 64, 32)
-        decoder_channels = (256, 128, 64, 32)
+        decoder_channels: tuple[int, ...] = (256, 128, 64, 32)
This should placate mypy
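The reason the annotation helps: mypy infers the type of a tuple literal as a fixed-length `tuple[int, int, int, int]`, so assigning a tuple of a different length in another branch fails type checking. Annotating the variable as `tuple[int, ...]` makes both branches compatible. A small sketch (the backbone name and the else-branch tuple are illustrative, not from the PR):

```python
# Without the annotation, mypy would infer tuple[int, int, int, int] from
# the first assignment and reject the 5-element tuple in the else branch.
backbone = "tu-vit_small_patch16_224"  # hypothetical backbone name

decoder_channels: tuple[int, ...]
if backbone.startswith("tu-vit"):
    decoder_channels = (256, 128, 64, 32)
else:
    decoder_channels = (256, 128, 64, 32, 16)  # ok: lengths may differ

print(decoder_channels)
```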
I don't know why the tests are failing, but we need to fix that.