Problems training a 3D Vision Transformer: model does not converge #296

Uljibuh opened this issue on Feb 6, 2024
Hi, this is my first time working with a transformer model, in this case a 3D vision transformer.

I am working on a 3D medical image classification task with a training set of around 300 3D images. Each input has shape (1, 224, 224, 32), where 1 is the number of channels and 32 is the z-dimension. I trained the same dataset on a 3D EfficientNet and reached around 80% accuracy, but when I tried the 3D vision transformer the model does not converge. Could you please review the code below? Why does the model not learn? Am I doing something wrong? Any help or suggestions would be appreciated. Thank you in advance.
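For context, a batch from my data loader matches the shapes printed further below (an illustrative tensor, not my actual loader code):

```python
import torch

# Illustrative batch matching the printed shapes:
# (batch, channels, height, width, z) = (4, 1, 224, 224, 32)
batch = torch.randn(4, 1, 224, 224, 32)
print(batch.shape)  # torch.Size([4, 1, 224, 224, 32])
```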

This is the forward pass:

```python
def forward(self, img):
    print("img,input shape before patch embedding", img.shape)

    x = self.to_patch_embedding(img)
    print("after patch embedding", x.shape)

    b, n, _ = x.shape
    # cls_tokens = self.cls_token.expand(b, -1, -1)
    cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
    print("cls token shape", cls_tokens.shape)

    x = torch.cat((cls_tokens, x), dim=1)
    print("after cls_token", x.shape)

    x += self.pos_embedding[:, :(n + 1)]
    print("after position embedding", x.shape)

    x = self.dropout(x)

    x = self.transformer(x)
    print("after transformer", x.shape)

    # 'cls' pooling keeps the class token; 'mean' pooling averages all tokens
    x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
    x = self.to_latent(x)
    print("after latent", x.shape)

    return self.mlp_head(x)
```

3D vision transformer model configuration:

```python
ViTmodel = ViT(
    image_size = 224,
    image_patch_size = 16,
    frames = 32,
    frame_patch_size = 4,
    num_classes = 2,
    dim = 1024,
    depth = 6,
    heads = 2,
    mlp_dim = 1024,
    pool = 'cls',
    channels = 1,
    dim_head = 64,
    dropout = 0.2,
    emb_dropout = 0.1
)
```
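As a sanity check, assuming the patch embedding splits each 224×224 plane into 16×16 patches and the 32-slice depth into chunks of 4, the token count implied by this configuration matches the 1568 printed below:

```python
# Token-count check implied by the configuration above
# (assumes 16x16 spatial patches and a frame patch size of 4).
image_size, image_patch_size = 224, 16
frames, frame_patch_size = 32, 4

num_patches = (image_size // image_patch_size) ** 2 * (frames // frame_patch_size)
print(num_patches)  # 14 * 14 * 8 = 1568; plus 1 cls token -> sequence length 1569
```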

```python
optimizer = optim.Adam(ViTmodel.parameters(), lr=0.002)
criterion = nn.CrossEntropyLoss().to(device)
```
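The training loop itself is not shown here; it follows the usual pattern (a minimal sketch assuming a standard DataLoader called train_loader; the names are illustrative, not my exact code):

```python
# Minimal sketch of one training epoch, assuming train_loader yields
# (volume, label) pairs and that device, ViTmodel, optimizer and criterion
# are defined as above. Labels are class indices of shape (batch,).
ViTmodel.to(device)
ViTmodel.train()

for imgs, labels in train_loader:
    imgs, labels = imgs.to(device), labels.to(device)

    optimizer.zero_grad()
    logits = ViTmodel(imgs)           # shape (batch, 2)
    loss = criterion(logits, labels)  # CrossEntropyLoss applies log-softmax internally
    loss.backward()
    optimizer.step()
```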

Shapes printed during the forward pass:

```
img,input shape before patch embedding torch.Size([4, 1, 224, 224, 32])
after patch embedding torch.Size([4, 1568, 1024])
cls token shape torch.Size([4, 1, 1024])
after cls_token torch.Size([4, 1569, 1024])
after position embedding torch.Size([4, 1569, 1024])
after transformer torch.Size([4, 1569, 1024])
after latent torch.Size([4, 1024])
Input shape: torch.Size([4, 1, 224, 224, 32])
Output shape: torch.Size([4, 2])
```
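Given the (4, 2) output, the targets passed to the loss are integer class indices (a small standalone usage example of nn.CrossEntropyLoss, not taken from my code):

```python
import torch
import torch.nn as nn

# With logits of shape (batch, 2), nn.CrossEntropyLoss expects integer class
# indices of shape (batch,). It applies log-softmax internally, so the model
# should return raw logits rather than probabilities.
criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 2)            # same shape as the model output above
targets = torch.tensor([0, 1, 1, 0])  # illustrative labels
loss = criterion(logits, targets)
print(loss.item())
```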

Model training results:

```
Epoch 1/10 (Training): 100%|██████████| 56/56 [00:59<00:00, 1.07s/it]
Epoch 1/10, Training Loss: 0.5908904586519513, Training Accuracy: 0.7142857142857143
Epoch 1/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.41it/s]
Epoch 1/10, Validation Loss: 0.5275474616459438, Validation Accuracy: 0.7798165137614679
Best model saved at epoch 1
Epoch 2/10 (Training): 100%|██████████| 56/56 [00:58<00:00, 1.04s/it]
Epoch 2/10, Training Loss: 0.5878153315612248, Training Accuracy: 0.7210884353741497
Epoch 2/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.40it/s]
Epoch 2/10, Validation Loss: 0.532904612166541, Validation Accuracy: 0.7798165137614679
Epoch 3/10 (Training): 100%|██████████| 56/56 [00:57<00:00, 1.03s/it]
Epoch 3/10, Training Loss: 0.5878153315612248, Training Accuracy: 0.7210884353741497
Epoch 3/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.40it/s]
Epoch 3/10, Validation Loss: 0.527547470160893, Validation Accuracy: 0.7798165137614679
Epoch 4/10 (Training): 100%|██████████| 56/56 [00:57<00:00, 1.03s/it]
Epoch 4/10, Training Loss: 0.5878153358186994, Training Accuracy: 0.7210884353741497
Epoch 4/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.41it/s]
Epoch 4/10, Validation Loss: 0.5329046036515918, Validation Accuracy: 0.7798165137614679
Epoch 5/10 (Training): 100%|██████████| 56/56 [00:57<00:00, 1.03s/it]
Epoch 5/10, Training Loss: 0.6034403315612248, Training Accuracy: 0.7210884353741497
Epoch 5/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.44it/s]
Epoch 5/10, Validation Loss: 0.532904612166541, Validation Accuracy: 0.7798165137614679
Epoch 6/10 (Training): 100%|██████████| 56/56 [00:57<00:00, 1.03s/it]
Epoch 6/10, Training Loss: 0.5878153379474368, Training Accuracy: 0.7210884353741497
Epoch 6/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.44it/s]
Epoch 6/10, Validation Loss: 0.527547470160893, Validation Accuracy: 0.7798165137614679
```
