
[FEATURE] Use of Nvidia Transformer Engine #1858

Closed
yazdanimehdi opened this issue Jun 21, 2023 · 1 comment
Labels: enhancement (New feature or request)

Comments

@yazdanimehdi

Using https://github.com/NVIDIA/TransformerEngine to speed up transformer-based models on the new NVIDIA Hopper GPUs and to enable float8 training.

Ideally it would detect that you are using Ada-based GPUs and switch to Transformer Engine accordingly.
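For reference, a minimal sketch of the kind of integration being requested, assuming the transformer_engine package and an FP8-capable GPU (Hopper H100 or Ada, e.g. a 4090); the layer sizes and the FP8 recipe are illustrative only, not anything timm currently does:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Drop-in TE replacements for nn.LayerNorm / nn.Linear (illustrative sizes)
norm = te.LayerNorm(768).cuda()
proj = te.Linear(768, 3072, bias=True).cuda()

# Delayed-scaling FP8 recipe; HYBRID = E4M3 forward, E5M2 backward
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

x = torch.randn(4096, 768, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = proj(norm(x))
```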


yazdanimehdi added the enhancement (New feature or request) label on Jun 21, 2023
@rwightman
Collaborator

@yazdanimehdi finally got around to picking up a 4090. It's nice, and gives a decent boost when using torch.compile.

I tried fiddling with Transformer Engine and FP8 autocast and it wasn't very helpful. I feel it would need the models rewritten to use fused layers and to fully integrate the attention. Just doing the 'easy' bits, such as converting nn.Linear and nn.LayerNorm and using te autocast, is slower than using torch native AMP with F.sdpa + bfloat16. With TE, the attention won't be using a fast kernel, some of the matmuls won't be cast to lower precision, and it seems you cannot combine torch autocast with te autocast.
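For comparison, a minimal sketch of the baseline described above: plain torch autocast in bfloat16 with F.scaled_dot_product_attention dispatching to a fused kernel. Shapes and names are illustrative, not taken from any timm model:

```python
import torch
import torch.nn.functional as F

# Illustrative ViT-like attention shapes: (batch, heads, tokens, head_dim)
q = torch.randn(8, 12, 197, 64, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Native torch AMP in bfloat16; F.scaled_dot_product_attention picks a
# fused flash / memory-efficient kernel where one is available.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = F.scaled_dot_product_attention(q, k, v)
```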

So, until torch includes Ada/Hopper compatible FP8 support & casting plus optimized kernels for e.g. F.sdpa, I don't think there is much point; I am not going to maintain multiple TE and non-TE versions of the various blocks / models, etc.
