dtype mismatch in AttentiveStatisticsPooling with FP16 training mode #2544

Open
MM-0712 opened this issue May 9, 2024 · 0 comments
Labels
bug Something isn't working

Comments

MM-0712 commented May 9, 2024

Describe the bug

attn = torch.cat([x, mean, std], dim=1)

If the model is trained in FP16 or BF16 mode, this line raises a dtype mismatch error because mean and std do not share x's dtype. One possible fix is to add .to(x.dtype) to the statistics before concatenating.
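A minimal sketch of the suggested cast, assuming mean and std end up in FP32 while x stays in half precision (the shapes below are illustrative, not taken from the issue):

import torch

x = torch.randn(2, 80, 120, dtype=torch.float16)          # FP16 features (batch, channels, time)
mean = x.float().mean(dim=2, keepdim=True).expand_as(x)   # statistics assumed to be accumulated in FP32
std = x.float().std(dim=2, keepdim=True).expand_as(x)

# Cast the FP32 statistics back to x's dtype so every tensor passed to torch.cat matches.
attn = torch.cat([x, mean.to(x.dtype), std.to(x.dtype)], dim=1)

With the casts in place, torch.cat no longer receives mixed float16/float32 inputs.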

Expected behaviour

None

To Reproduce

None

Environment Details

No response

Relevant Log Output

No response

Additional Context

No response

MM-0712 added the bug label on May 9, 2024