4khd-7b: error during multi-image SFT #311

Open
zws-2019 opened this issue May 10, 2024 · 3 comments
Comments

@zws-2019

I passed in two images, with shape:
torch.Size([2, 3, 1680, 1008])

When execution reaches:
self.vit([image], self.plora_glb_GN, self.plora_sub_GN)

it fails with:
RuntimeError: shape '[1, 3, 5, 336, 3, 336]' is invalid for input of size 10160640

A single image works fine; the error only appears with two images.
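
The numbers in the error message can be checked by hand. The sketch below only multiplies out the two shapes quoted above; the variable names are illustrative and not taken from the repository. It suggests that the target shape covers exactly one image while the input holds two.

```python
from math import prod

# Input reported in the issue: two images, each 3 x 1680 x 1008.
input_shape = (2, 3, 1680, 1008)
# Target shape from the error message (leading dimension is 1).
target_shape = (1, 3, 5, 336, 3, 336)

n_input = prod(input_shape)
n_target = prod(target_shape)

print(n_input)              # 10160640, the size named in the RuntimeError
print(n_target)             # 5080320
print(n_input // n_target)  # 2, i.e. the number of images
```

The same check on the second traceback in this thread (8128512 elements against '[1, 3, 3, 336, 4, 336]') also gives a factor of 2.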

@plmsmile

Same here. The line that fails:
sub_img = img.reshape(1,3,H//336,336,W//336,336).permute(0,2,4,1,3,5).reshape(-1,3,336,336).contiguous()

RuntimeError: shape '[1, 3, 3, 336, 4, 336]' is invalid for input of size 8128512

@plmsmile

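Also, when the images have different shapes, they need to be resized to a common shape first. I modified Sample_dataset in data_mix.py so that all images of a multi-image sample get the same shape.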
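
The change to Sample_dataset is not shown in this thread, so the following is only a minimal sketch of one way to bring several images to a common shape. It assumes each image is a (3, H, W) tensor and, as an arbitrary choice, uses the first image's size as the target; resize_to_common_shape is a hypothetical helper, not a function from data_mix.py.

```python
import torch
import torch.nn.functional as F

def resize_to_common_shape(images, tile=336):
    """Resize every (3, H, W) image to the size of the first one and stack them."""
    target_h, target_w = images[0].shape[-2:]
    # The tiling reshape quoted in this thread needs both sides to be multiples of 336.
    assert target_h % tile == 0 and target_w % tile == 0
    resized = [
        F.interpolate(img.unsqueeze(0), size=(target_h, target_w),
                      mode="bilinear", align_corners=False).squeeze(0)
        for img in images
    ]
    return torch.stack(resized)  # (cnt, 3, target_h, target_w)

imgs = [torch.rand(3, 672, 336), torch.rand(3, 336, 672)]
batch = resize_to_common_shape(imgs)
print(batch.shape)  # torch.Size([2, 3, 672, 336])
```

Resizing to the first image's size changes the aspect ratio of the other images; padding or picking a shared grid would be alternative choices.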

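That still failed in build_mlp.py, though. I then changed the first dimension of the sub_img reshape to cnt (1 for a single image, the number of images for multi-image input), and after that it ran normally.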

sub_img = img.reshape(cnt,3,H//336,336,W//336,336).permute(0,2,4,1,3,5).reshape(-1,3,336,336).contiguous()
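
As a standalone illustration of the line above, here is a minimal sketch that wraps the same reshape/permute chain in a function and checks one tile against a direct slice. split_into_tiles is a hypothetical name, and cnt is read from the tensor's first dimension, which is an assumption about how it would be obtained in build_mlp.py.

```python
import torch

def split_into_tiles(img, tile=336):
    """Cut a (cnt, C, H, W) batch into (cnt * H//tile * W//tile, C, tile, tile) tiles."""
    cnt, C, H, W = img.shape
    assert H % tile == 0 and W % tile == 0
    return (
        img.reshape(cnt, C, H // tile, tile, W // tile, tile)
        .permute(0, 2, 4, 1, 3, 5)   # -> (cnt, rows, cols, C, tile, tile)
        .reshape(-1, C, tile, tile)
        .contiguous()
    )

imgs = torch.rand(2, 3, 1680, 1008)   # the shape reported at the top of the issue
tiles = split_into_tiles(imgs)
print(tiles.shape)  # torch.Size([30, 3, 336, 336])

# Tile (row 2, col 1) of the second image sits at index 1*15 + 2*3 + 1.
assert torch.equal(tiles[15 + 2 * 3 + 1], imgs[1, :, 672:1008, 336:672])
```

The tiles come out grouped by image (15 for the first image, then 15 for the second), which matters for any later code that slices them per image.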


@zws-2019
Author

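Also, when the images have different shapes, they need to be resized to a common shape first. I modified Sample_dataset in data_mix.py so that all images of a multi-image sample get the same shape.

That still failed in build_mlp.py, though. I then changed the first dimension of the sub_img reshape to cnt (1 for a single image, the number of images for multi-image input), and after that it ran normally.

sub_img = img.reshape(cnt,3,H//336,336,W//336,336).permute(0,2,4,1,3,5).reshape(-1,3,336,336).contiguous()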


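That does seem to get it running.
However, the 4khd model's processing logic does not appear to support multiple images.
For example, here only the first image_feature is used as glb_img, so with multiple images the logic breaks: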

        for [h, w] in shapes:
            B_ = h*w
            glb_img = image_features[:1] ### 1, N, C
            glb_img = glb_img.reshape(1,H,H,C).reshape(1,H//2,2,H//2,2,C).contiguous().permute(0,1,3,2,4,5).reshape(1,H//2,H//2,4*C).contiguous()
            temp_glb_GN = sub_GN.repeat(1, H//2, 1, 1)
            glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(1,-1,4*C)
            
            sub_img = image_features[1:1+B_] ### ?, N, C
            sub_img = sub_img.reshape(B_,H,H,C).reshape(B_,H//2,2,H//2,2,C).contiguous().permute(0,1,3,2,4,5).reshape(B_,-1,4*C).contiguous()
            sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute(0,1,3,2,4,5).reshape(1,h*12,w*12,4*C)
            temp_sub_GN = sub_GN.repeat(1, h*12, 1, 1)
            sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1,-1,4*C)

            output_imgs.append(torch.cat([glb_img, glb_GN, sub_img], dim=1))
            temp_len = int((h*w+1)*144 + 1 + (h+1)*12)
            assert temp_len == output_imgs[-1].shape[1]
            output_len.append(temp_len)

            image_features = image_features[1+h*w:]
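
To see what ordering the loop above expects from image_features, its slicing can be mirrored in plain Python. This sketch assumes only what the quoted code shows: for each [h, w] in shapes it takes one global feature, then h*w sub-image features, and uses the temp_len formula for the token count. expected_layout is a hypothetical helper, not part of the repository.

```python
def expected_layout(shapes):
    """Return (start, n_features, n_tokens) per image for the loop's slicing scheme."""
    layout = []
    start = 0
    for h, w in shapes:
        n_features = 1 + h * w  # one global view followed by h*w sub-image tiles
        n_tokens = (h * w + 1) * 144 + 1 + (h + 1) * 12  # same formula as temp_len
        layout.append((start, n_features, n_tokens))
        start += n_features
    return layout

# Two 1680 x 1008 images give a 5 x 3 tile grid each.
for start, n_features, n_tokens in expected_layout([(5, 3), (5, 3)]):
    print(start, n_features, n_tokens)
# 0 16 2377
# 16 16 2377
```

Under this layout the global view of the second image has to sit at index 16. If the features for a multi-image input are ordered differently (for example, all sub-image tiles grouped together), image_features[:1] would no longer pick up each image's global view.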
