
DEFAULT_MASK_VALUE causes gradient explosion and nan loss on deep models #614

Open
logicchains opened this issue Apr 23, 2024 · 1 comment

@logicchains

I was training a Llama model on GPU with a custom embedding. It worked fine with 12 layers, dim 1024, and sequence length 256, but the loss became NaN after the first step when num_layers was set to more than 17. I debugged the gradients and found that their magnitude grew by roughly 100x per layer, until they hit float32 max at around the 18th layer and became inf, producing the NaN loss.

The gradient explosion appeared to come from

    local_exps = jnp.exp(attn_weights - local_max)

in attentions.py.

Changing

    DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)

to

    DEFAULT_MASK_VALUE = -jnp.inf

fixed the issue, and the gradients' magnitude stopped increasing after each layer.
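
For illustration, here is a minimal, hypothetical sketch of the masked, max-stabilized softmax pattern described above (not MaxText's actual attentions.py code), with a small harness for checking whether gradients stay finite under either mask value. The blow-up reported here accumulated over many layers of the full model, so a single-layer toy like this will not necessarily reproduce it.

    # Hypothetical sketch, not MaxText's attentions.py: a masked, max-stabilized
    # softmax plus a small gradient probe for the two candidate mask values.
    import jax
    import jax.numpy as jnp

    HUGE_NEG = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)  # current DEFAULT_MASK_VALUE

    def masked_softmax(attn_weights, mask, mask_value):
        # Fill masked-out logits with mask_value, then apply the same
        # local_max / local_exps stabilization pattern quoted above.
        weights = jnp.where(mask, attn_weights, mask_value)
        local_max = jnp.max(weights, axis=-1, keepdims=True)
        local_exps = jnp.exp(weights - local_max)
        return local_exps / jnp.sum(local_exps, axis=-1, keepdims=True)

    def loss_fn(attn_weights, mask, mask_value):
        # Arbitrary scalar loss so jax.grad has something to differentiate.
        return jnp.sum(masked_softmax(attn_weights, mask, mask_value) ** 2)

    attn_weights = jax.random.normal(jax.random.PRNGKey(0), (4, 8), dtype=jnp.float32)
    mask = jnp.tril(jnp.ones((4, 8), dtype=bool))  # causal-style mask; every row keeps >= 1 position

    for mask_value in (HUGE_NEG, -jnp.inf):
        grads = jax.grad(loss_fn)(attn_weights, mask, mask_value)
        print(mask_value, "all gradients finite:", bool(jnp.all(jnp.isfinite(grads))))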

Presumably the issue wasn't noticed during TPU training, as that uses a separate code path.

@rwitten
Collaborator

rwitten commented Apr 30, 2024

@logicchains thanks for the tip on GPU convergence! We will experiment with this as we set up convergence regimes for GPUs.

@anfals please be aware of this as you do convergence testing on GPU
