
CUDA memory for 3D ViT #300

Open
JesseZZZZZ opened this issue Mar 20, 2024 · 2 comments

Comments

@JesseZZZZZ

[screenshot: CUDA out-of-memory error requesting ~356 GiB]
This 356 GiB is a little stunning... I don't think I changed the original code much, so does anyone know whether this is my mistake or whether the original model itself needs this much CUDA memory? Thanks a lot!
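For a rough sense of where the memory goes: the 3D ViT turns a clip into (image_size / image_patch_size)² × (frames / frame_patch_size) tokens, and vanilla attention materialises a tokens × tokens matrix per head per layer, so memory grows quadratically with the token count. A minimal back-of-envelope sketch, using README-style settings (the actual settings behind the screenshot are not shown, so these numbers are only illustrative):

# back-of-envelope attention memory for a 3D ViT (illustrative numbers only)
image_size, image_patch_size = 128, 16
frames, frame_patch_size = 16, 2
batch, heads = 4, 8

# tokens = spatial patches per frame * temporal patches
tokens = (image_size // image_patch_size) ** 2 * (frames // frame_patch_size)  # 64 * 8 = 512

# one float32 attention matrix per head: batch * heads * tokens^2 * 4 bytes
attn_bytes_per_layer = batch * heads * tokens ** 2 * 4

print(tokens)                          # 512
print(attn_bytes_per_layer / 2 ** 20)  # ~32 MiB per layer at these settings

At these settings the attention matrices are small, so a figure like 356 GiB usually points to a much larger token count (bigger images, more frames, smaller patches) or a much larger batch; the quadratic term is what blows up first.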

lucidrains added a commit that referenced this issue May 2, 2024
@lucidrains
Owner

lucidrains commented May 2, 2024

@JesseZZZZZ

try

import torch
from vit_pytorch.simple_flash_attn_vit_3d import SimpleViT

v = SimpleViT(
    image_size = 128,          # spatial size of each frame (height and width)
    frames = 16,               # number of frames in the clip
    image_patch_size = 16,     # spatial patch size
    frame_patch_size = 2,      # temporal patch size
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    use_flash_attn = True
)

video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)

preds = v(video) # (4, 1000)

This should help with memory, but you'll still face the compute cost.
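One way to confirm the saving (a minimal sketch, assuming a CUDA device is available and `v` is the model from the snippet above; these are standard PyTorch memory-stat calls, not part of vit-pytorch):

import torch

torch.cuda.reset_peak_memory_stats()

v = v.cuda()
video = torch.randn(4, 3, 16, 128, 128, device = 'cuda')
preds = v(video)

# peak memory allocated by the forward pass, in GiB
print(torch.cuda.max_memory_allocated() / 2 ** 30)

Running this once with use_flash_attn = True and once without gives a direct before/after comparison of peak memory.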

@JesseZZZZZ
Author


Thank you so much! It does fix my problem to some extent!
