[BUG]: OOM during llama2 pretraining with flashattention and PP #5549

Open
insujang opened this issue Apr 3, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@insujang
Contributor

insujang commented Apr 3, 2024

🐛 Describe the bug

I understand that this error comes out of the flash attention software stack, but there seems to be no related issue except for Dao-AILab/flash-attention#590, so I am opening an issue here anyway. The problem also happens with flash-attn 2.0.5.

Using PP in HybridParallelPlugin (no ZeRO) together with flash attention for Llama2 results in an OOM.

When I try to run examples/language/llama2/pretrain.py, the padding handling for flash attention (unpad_input) hits an OOM. Without flash attention it works fine.

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    # all the other args are the same as in the example
)

Note that if you set pp_size=1, you will get the "cache only has 0 layers" exception (#5410) even before facing the OOM :) so there seems to be another bug in the Llama2 forward with attention parallelism. Just a side note.

PYTHONPATH=/path/to/colossalai/examples/language/llama2 torchrun --standalone --nproc-per-node 4 pretrain.py -p hybrid_parallel -a -g -x bf16 -o /tmp/llama_checkpoint
  File "/data/insujang/colossalai/examples/language/llama2/attn.py", line 174, in attention_forward
    q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
  File "/opt/conda/lib/python3.10/site-packages/flash_attn-2.5.6-py3.10-linux-x86_64.egg/flash_attn/bert_padding.py", line 119, in unpad_input
    index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/lib/python3.10/site-packages/flash_attn-2.5.6-py3.10-linux-x86_64.egg/flash_attn/bert_padding.py", line 17, in forward
    return torch.gather(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 59.67 GiB. GPU 1 has a total capacity of 44.35 GiB of which 34.03 GiB is free. Process 1325526 has 10.31 GiB memory in use. Of the allocated memory 9.76 GiB is allocated by PyTorch, and 40.83 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
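For what it's worth, the failed allocation size looks consistent with the gather materializing one row of q per nonzero mask entry. Assuming each per-rank row of q is 2048 bf16 elements (4096 hidden / tp_size=2 for the 7B config; my assumption, not from the logs), the 15,642,705 indices (the indices.shape printed further below) work out to almost exactly 59.67 GiB:

# back-of-the-envelope check; 2048 bf16 elements per row is my assumption
num_indices = 15_642_705                     # indices.shape from the printout below
bytes_per_row = 2048 * 2                     # assumed per-TP-rank row of q in bf16
print(num_indices * bytes_per_row / 2**30)   # ~59.67, matching the 59.67 GiB allocation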

q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)

I think this might be related to the size of attention_mask, but I am not sure:

# from flash_attn/bert_padding.py
def unpad_input(hidden_states, attention_mask):
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # attention_mask.shape=torch.Size([1, 1, 4096, 4096])
    # indices.shape=torch.Size([15642705])
    ...
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), # Error here
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )

# index_first_axis calls IndexFirstAxis.forward()
class IndexFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ...
        return torch.gather(
            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim) # Error here
        ).reshape(-1, *other_shape)
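If I am reading bert_padding.py correctly, unpad_input expects attention_mask to be a 2D (batch, seqlen) padding mask, so indices should never exceed batch * seqlen (4096 here). Flattening the 4D [1, 1, 4096, 4096] mask instead yields on the order of seqlen^2 nonzero entries, which is where the ~15.6M indices above come from. A minimal standalone sketch of the blow-up (the causal mask below is only for illustration, not the exact mask from my run):

import torch

batch, seqlen = 1, 4096

# What unpad_input expects: a (batch, seqlen) padding mask,
# so indices can never exceed batch * seqlen.
mask_2d = torch.ones(batch, seqlen, dtype=torch.bool)
indices_2d = torch.nonzero(mask_2d.flatten(), as_tuple=False).flatten()
print(indices_2d.shape)  # torch.Size([4096])

# What apparently arrives here: a (batch, 1, seqlen, seqlen) mask.
mask_4d = torch.ones(seqlen, seqlen).tril().bool()[None, None]
indices_4d = torch.nonzero(mask_4d.flatten(), as_tuple=False).flatten()
print(indices_4d.shape)  # torch.Size([8390656]) for a pure causal mask; my run saw 15642705

index_first_axis then gathers one row of q per index, hence the ~60 GiB request.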

where attention_mask is created here:

if self._use_flash_attention_2:
    # 2d mask is passed through the layers
    attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
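So with flash attention 2, transformers is supposed to pass the mask through as 2D, yet a 4D mask reaches the unpad path in attn.py. Purely as an untested sketch of the shape fix (assuming the 4D mask is boolean with True meaning "may attend"; the helper name is mine, not from the codebase):

import torch

def to_key_padding_mask(mask_4d: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: for a (batch, 1, q_len, kv_len) boolean causal mask,
    # the last query row attends to every non-padded key, so slicing it
    # recovers the (batch, kv_len) padding pattern that unpad_input expects.
    return mask_4d[:, 0, -1, :]

# key_padding_mask = to_key_padding_mask(attention_mask)
# q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)

That said, the proper fix is probably to make sure a 2D mask reaches attention_forward in the first place.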

I would appreciate it if you could check whether this is reproducible and what the cause is.

Environment

4x NVIDIA A40 (48 GB each)
PyTorch 2.2.1 | CUDA 12.1
ColossalAI branch: feature/update-transformers
transformers 4.36.0
flash-attn 2.5.6

insujang added the bug label on Apr 3, 2024
@insujang
Contributor Author

insujang commented Apr 3, 2024

@wangbluo Could you please help me solve this issue? Thanks

@wangbluo
Contributor

wangbluo commented Apr 3, 2024

@wangbluo Could you please help me solve this issue? Thanks

Hi, could you please share the model size you are using?

@insujang
Contributor Author

insujang commented Apr 3, 2024

I used the 7B configuration.
