
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (CUDABFloat16Type) should be the same #379

Closed
Echoxvf opened this issue May 7, 2024 · 5 comments
Labels: bug (Something isn't working), question (Further information is requested)

Comments

@Echoxvf

Echoxvf commented May 7, 2024

At the start of training, I get an error because the input data type does not match the dtype of the convolution weights. How can this be resolved?
  File "/data/scripts/train.py", line 254, in main
    loss_dict = scheduler.training_losses(model, x, t, model_args, mask=mask)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/opensora/schedulers/iddpm/respace.py", line 98, in training_losses
    return super().training_losses(self._wrap_model(model), *args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/opensora/schedulers/iddpm/gaussian_diffusion.py", line 768, in training_losses
    model_output = model(x_t, t, **model_kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/opensora/schedulers/iddpm/respace.py", line 127, in __call__
    return self.model(x, new_ts, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/colossalai/booster/plugin/low_level_zero_plugin.py", line 65, in forward
    return super().forward(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/colossalai/interface/model.py", line 25, in forward
    return self.module(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/opensora/models/stdit/stdit.py", line 276, in forward
    x = self.x_embedder(x)  # [B, N, C]
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/opensora/models/layers/blocks.py", line 121, in forward
    x = self.proj(x)  # (B C T H W)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 610, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 605, in _conv_forward
    return F.conv3d(
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (CUDABFloat16Type) should be the same
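The traceback says F.conv3d received a float32 input while the Conv3d weights of the patch embedder had been cast to bfloat16. A minimal sketch of the mismatch and of the usual remedy (cast the input to the weight dtype before the call) follows; the helper names `conv_with_matching_dtype` and `reproduces_mismatch` and the layer shapes are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn


def conv_with_matching_dtype(conv: nn.Conv3d, x: torch.Tensor) -> torch.Tensor:
    """Run `conv` after casting `x` to the dtype of the conv weights.

    Without the cast, a float32 input and bfloat16 weights make F.conv3d
    raise "Input type (...) and weight type (...) should be the same".
    """
    return conv(x.to(conv.weight.dtype))


def reproduces_mismatch(conv: nn.Conv3d, x: torch.Tensor) -> bool:
    """Return True if calling `conv` directly on `x` raises a RuntimeError
    (in this sketch, the input/weight dtype mismatch)."""
    try:
        conv(x)
    except RuntimeError:
        return True
    return False
```

On a CUDA device, a layer built as `nn.Conv3d(4, 8, kernel_size=(1, 2, 2)).cuda().to(torch.bfloat16)` and fed a float32 CUDA tensor should hit the same error as above, while `conv_with_matching_dtype` avoids it.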

@zhengzangw
Collaborator

Could you share the exact command you ran? I ran the following and everything works:

torchrun --standalone --nproc_per_node 1 scripts/train.py configs/opensora-v1-1/train/stage1.py --data-path MY_DATA_PATH

My guess is that either dtype is not set correctly in your config, or you have not pulled the latest version of the repo.
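Open-Sora training configs are plain Python files, and this error typically appears when the model weights are cast to the configured dtype but the inputs are not. As a hedged sketch (the helpers `dtype_from_config` and `param_dtypes` are hypothetical, and the accepted strings "fp32", "fp16", "bf16" are an assumption about how a config spells the dtype), mapping a config string to a torch dtype and checking which dtypes a model actually holds could look like this:

```python
import torch
import torch.nn as nn

# Hypothetical mapping from config strings to torch dtypes (assumed spellings).
_DTYPES = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


def dtype_from_config(name: str) -> torch.dtype:
    """Translate a dtype string such as "bf16" into a torch.dtype."""
    try:
        return _DTYPES[name]
    except KeyError:
        raise ValueError(f"unsupported dtype string: {name!r}") from None


def param_dtypes(model: nn.Module) -> set:
    """Collect the distinct dtypes of all parameters in `model`."""
    return {p.dtype for p in model.parameters()}
```

If `param_dtypes(model)` reports `{torch.bfloat16}` while the latents fed to the model are float32, the inputs have to be cast before the first layer.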

@zhengzangw added the question label and removed the help wanted label May 9, 2024
@Echoxvf
Author

Echoxvf commented May 9, 2024

Thank you for your response. I was using the Open-Sora 1.0 training command:
torchrun --nnodes=1 --nproc_per_node=1 scripts/train.py configs/opensora/train/16x256x256.py --data-path YOUR_CSV_PATH
The Open-Sora 1.1 training command works fine.
Thanks again.

@TXacs

TXacs commented May 11, 2024

I have the same problem: even after checking out v1.1.0 and running the script, I get the same error.

My command:
torchrun --nproc-per-node=4 scripts/train.py configs/pixart/train/1x512x512.py --data-path CSV_FILE

I have not changed the config.

@zhengzangw
Collaborator

This issue occurs because we updated the model configs to follow Hugging Face but did not update the older ones.

@zhengzangw added the bug label May 11, 2024
@TXacs

TXacs commented May 11, 2024

Thanks a lot! Adding the following code to pixart.py (not only stdit.py) solved the problem:

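def forward(self, x, timestep, y, mask=None):
    """
    Forward pass of PixArt.
    x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
    timestep: (N,) tensor of diffusion timesteps
    y: (N, 1, 120, C) tensor of class labels
    """
    # Cast the inputs to the dtype of the patch-embedding weights (e.g. bf16),
    # so that F.conv3d in x_embedder.proj sees matching input and weight types.
    dtype = self.x_embedder.proj.weight.dtype
    x = x.to(dtype)
    timestep = timestep.to(dtype)
    y = y.to(dtype)
    # ... the rest of forward() is unchanged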


@Echoxvf Echoxvf closed this as completed May 16, 2024