Trouble loading ViT - Dino structure for channels>3? #291

Open
AgentM-GEG opened this issue Dec 22, 2023 · 0 comments

AgentM-GEG commented Dec 22, 2023

Hi,

I am trying to run the ViT + Dino example illustrated in the repository, with slightly changed parameters (channels = 4 and image_size = 224). The example works as expected when channels = 3, but with channels = 4 it fails with: RuntimeError: Given normalized_shape=[4096], expected input with shape [*, 4096], but got input of size [2, 49, 3072]. The numbers suggest that a 3-channel tensor is reaching a patch embedding built for 4 channels: with patch_size = 32, the LayerNorm expects 4 * 32 * 32 = 4096 values per flattened patch but is receiving 3 * 32 * 32 = 3072. I feel like something has been hardcoded to 3 channels inside dino.py. Please suggest any changes or recommendations; I may well be missing something obvious here.

EDIT: I think it may be because of this line: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/dino.py#L249
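
For what it's worth, the shapes in the error are reproducible outside of Dino by feeding a 4-channel ViT a 3-channel batch, which is what makes me suspect a tensor with 3 channels hard-coded somewhere in dino.py. This is just a quick check with random data, using the same image and patch size as my setup but a tiny model:

import torch
from vit_pytorch import ViT

# minimal 4-channel ViT with the same image/patch size as my setup below
vit = ViT(image_size = 224, patch_size = 32, num_classes = 10,
          dim = 64, depth = 1, heads = 1, mlp_dim = 128, channels = 4)

# the patch-embedding LayerNorm is built for 4 * 32 * 32 = 4096 values per
# flattened patch; a 3-channel 224x224 batch instead gives 49 patches of
# 3 * 32 * 32 = 3072 values, which is exactly the shape in the error above
vit(torch.randn(2, 3, 224, 224))

My full configuration is below: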

my_model = ViT(
    image_size = 224,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    channels = 4
)

learner_model = Dino(
    my_model,
    image_size = 224,
    hidden_layer = 'to_latent',        # hidden layer name or index, from which to extract the embedding
    projection_hidden_size = 256,      # projector network hidden dimension
    projection_layers = 4,             # number of layers in projection network
    num_classes_K = 50176,             # output logits dimensions (referenced as K in paper)
    student_temp = 0.9,                # student temperature
    teacher_temp = 0.04,               # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
    local_upper_crop_scale = 0.4,      # upper bound for local crop - 0.4 was recommended in the paper 
    global_lower_crop_scale = 0.5,     # lower bound for global crop - 0.5 was recommended in the paper
    moving_average_decay = 0.9,        # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
    center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
)
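
And just to rule out the base model: my_model on its own accepts 4-channel input without complaint, so the failure seems to be confined to the Dino wrapper (again just random data):

import torch

images = torch.randn(2, 4, 224, 224)   # random 4-channel batch
preds = my_model(images)               # runs fine with channels = 4
print(preds.shape)                     # torch.Size([2, 1000])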
