When I set per_device_train_batch_size=2, S2-Attn does not shift as expected #182

linhaojia13 opened this issue Mar 1, 2024 · 2 comments

@linhaojia13

First, I ran the following command:

CUDA_VISIBLE_DEVICES=1 torchrun --nproc_per_node=1 --master_port=29501 supervised-fine-tune.py  \
        --model_name_or_path /mnt/42_store/lhj/data/mllm/model_weights/Llama-2-7b-chat-hf \
        --bf16 True \
        --output_dir outputs \
        --model_max_length 16384 \
        --use_flash_attn True \
        --data_path /data/mllm/LongAlpaca-16k-length/LongAlpaca-16k-length.json \
        --low_rank_training True \
        --num_train_epochs 5  \
        --per_device_train_batch_size 1     \
        --per_device_eval_batch_size 2     \
        --gradient_accumulation_steps 8     \
        --evaluation_strategy "no"     \
        --save_strategy "steps"     \
        --save_steps 98     \
        --save_total_limit 2     \
        --learning_rate 2e-5     \
        --weight_decay 0.0     \
        --warmup_steps 20     \
        --lr_scheduler_type "constant_with_warmup"     \
        --logging_steps 1     \
        --deepspeed "ds_configs/stage2.json" \
        --tf32 True

The variable cu_q_lens just before flash_attn_varlen_qkvpacked_func is as follows:

(Pdb) cu_q_lens
tensor([    0,  8192,  9243, 13339, 18486], device='cuda:0', dtype=torch.int32)
(Pdb) 

This looks correct: the group size is 8192, and the second half of the heads is shifted by 4096 (13339 - 9243).
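
For reference, the following is a minimal standalone sketch of how this cu_q_lens is built, using the same construction as the snippet from llama_attn_replace_sft.py quoted later in this issue. The sequence length 9243 is read off the trace above, and it assumes unpad_input returned [0, 9243, 18486], i.e. the shifted half of the heads is appended as a second unpadded sequence; this is not the real forward pass.

    import torch

    # Minimal sketch (not the real forward pass): rebuild cu_q_lens for batch size 1,
    # group_size 8192, one unpadded sequence of length 9243 plus its shifted copy.
    group_size = 8192
    bsz = 1
    max_s = 9243
    cu_q_lens = torch.tensor([0, 9243, 18486], dtype=torch.int32)  # assumed unpad_input output

    cu_q_len_tmp = torch.arange(0, max_s, group_size, dtype=cu_q_lens.dtype)      # [0, 8192]
    cu_q_len_tmp2 = cu_q_len_tmp + group_size // 2                                 # [4096, 12288]
    cu_q_len_tmp2[cu_q_len_tmp2 >= max_s] = torch.iinfo(cu_q_len_tmp2.dtype).min   # drop out-of-range boundary
    cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp2]).repeat(bsz, 1) + cu_q_lens[:-1].unsqueeze(-1)
    cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)
    cu_q_lens = cu_q_lens[cu_q_lens >= 0]
    print(cu_q_lens)  # tensor([0, 8192, 9243, 13339, 18486], dtype=torch.int32)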

However, when I set per_device_train_batch_size=2 and run the following command:

CUDA_VISIBLE_DEVICES=1 torchrun --nproc_per_node=1 --master_port=29501 supervised-fine-tune.py  \
        --model_name_or_path /mnt/42_store/lhj/data/mllm/model_weights/Llama-2-7b-chat-hf \
        --bf16 True \
        --output_dir outputs \
        --model_max_length 16384 \
        --use_flash_attn True \
        --data_path /data/mllm/LongAlpaca-16k-length/LongAlpaca-16k-length.json \
        --low_rank_training True \
        --num_train_epochs 5  \
        --per_device_train_batch_size 2     \
        --per_device_eval_batch_size 2     \
        --gradient_accumulation_steps 8     \
        --evaluation_strategy "no"     \
        --save_strategy "steps"     \
        --save_steps 98     \
        --save_total_limit 2     \
        --learning_rate 2e-5     \
        --weight_decay 0.0     \
        --warmup_steps 20     \
        --lr_scheduler_type "constant_with_warmup"     \
        --logging_steps 1     \
        --deepspeed "ds_configs/stage2.json" \
        --tf32 True

the variable cu_q_lens right after unpad_input is as follows:

(Pdb) cu_q_lens
tensor([    0,  9243, 25439, 34682, 50878], device='cuda:0', dtype=torch.int32)
(Pdb) 

The final cu_q_lens passed to flash_attn_varlen_qkvpacked_func is then:

(Pdb) cu_q_lens
tensor([    0,  8192,  9243, 13339, 21531, 25439, 25439, 33631, 34682, 38778,
        46970, 50878], device='cuda:0', dtype=torch.int32)
(Pdb)

It is easy to see that the second-half heads (25439, 33631, 34682, 38778, 46970, 50878) share the same intervals as the first-half heads (0, 8192, 9243, 13339, 21531, 25439). In other words, the second-half heads are not shifted by half the group size.
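
To make the misalignment concrete, here is another small standalone sketch that reproduces the offset rows for this batch-size-2 run. The sequence lengths 9243 and 16196 are read off the unpad_input trace above, and the unpadded batch is assumed to hold the two unshifted halves followed by the two shifted halves:

    import torch

    # Sketch of the original offset construction for batch size 2
    # (the cu_q_len_tmp2 >= max_s masking is a no-op here and omitted).
    group_size = 8192
    bsz = 2
    max_s = 16196
    cu_q_lens = torch.tensor([0, 9243, 25439, 34682, 50878], dtype=torch.int32)  # from the trace above

    cu_q_len_tmp = torch.arange(0, max_s, group_size, dtype=cu_q_lens.dtype)  # [0, 8192]
    cu_q_len_tmp2 = cu_q_len_tmp + group_size // 2                            # [4096, 12288]
    rows = torch.stack([cu_q_len_tmp, cu_q_len_tmp2]).repeat(bsz, 1)
    # The four rows are applied to the four unpadded sequences in order:
    #   row 0 -> sample 1, unshifted half: offsets [0, 8192]      (correct)
    #   row 1 -> sample 2, unshifted half: offsets [4096, 12288]  (shifted, but should not be)
    #   row 2 -> sample 1, shifted half:   offsets [0, 8192]      (should be shifted, but is not)
    #   row 3 -> sample 2, shifted half:   offsets [4096, 12288]  (correct only by coincidence)
    print(rows + cu_q_lens[:-1].unsqueeze(-1))
    # [[0, 8192], [13339, 21531], [25439, 33631], [38778, 46970]]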

Could you please fix this bug?

@linhaojia13
Author

I modified llama_attn_replace_sft.py from

    x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
    cu_q_len_tmp = torch.arange(0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype)
    cu_q_len_tmp2 = cu_q_len_tmp + group_size // 2
    cu_q_len_tmp2[cu_q_len_tmp2 >= max_s] = torch.iinfo(cu_q_len_tmp2.dtype).min
    cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp2]).repeat(bsz, 1) + cu_q_lens[:-1].unsqueeze(-1)
    cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)
    cu_q_lens = cu_q_lens[cu_q_lens >= 0]

to

    x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
    cu_q_len_tmp = torch.arange(0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype)
    cu_q_len_tmp2 = cu_q_len_tmp + group_size // 2
    cu_q_len_tmp2[cu_q_len_tmp2 >= max_s] = torch.iinfo(cu_q_len_tmp2.dtype).min
    # cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp2]).repeat(bsz, 1) + cu_q_lens[:-1].unsqueeze(-1)
    cu_q_len_tmp1p5 = cu_q_len_tmp + group_size
    cu_q_len_tmp1p5[cu_q_len_tmp1p5 >= max_s] = torch.iinfo(cu_q_len_tmp1p5.dtype).min
    cu_q_len_tmp = torch.cat([cu_q_len_tmp.unsqueeze(0), cu_q_len_tmp1p5.repeat(bsz-1, 1), cu_q_len_tmp2.repeat(bsz, 1)], dim=0) + cu_q_lens[:-1].unsqueeze(-1)
    cu_q_len_tmp[cu_q_len_tmp > cu_q_lens[1:].unsqueeze(-1)] = torch.iinfo(cu_q_len_tmp1p5.dtype).min
    cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)
    cu_q_lens = cu_q_lens[cu_q_lens >= 0]

After this modification, the final cu_q_lens before flash_attn_varlen_qkvpacked_func looks correct:

(Pdb) cu_q_lens
tensor([    0,  8192,  9243, 17435, 25439, 29535, 34682, 38778, 46970, 50878],
       device='cuda:0', dtype=torch.int32)
(Pdb) 

The second-half heads are now shifted by 4096 (29535 - 25439, 38778 - 34682).

It would be great if you could check whether my modifications are correct.
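
In case it helps with that check, here is a small standalone harness (the function name build_cu_q_lens and the sample lengths are made up for illustration) that wraps the modified construction so it can be run for any batch size without launching training. It assumes the same unpadded layout as above: the unshifted halves of all samples first, then their shifted halves.

    import torch

    def build_cu_q_lens(seq_lens, group_size=8192):
        """Rebuild cu_q_lens with the modified construction for the given per-sample lengths."""
        bsz = len(seq_lens)
        # Assumed unpadded layout: unshifted halves of all samples, then their shifted halves.
        lens = torch.tensor(list(seq_lens) * 2, dtype=torch.int32)
        cu_q_lens = torch.cat([torch.zeros(1, dtype=torch.int32), lens.cumsum(0).to(torch.int32)])
        max_s = int(lens.max())

        cu_q_len_tmp = torch.arange(0, max_s, group_size, dtype=cu_q_lens.dtype)
        cu_q_len_tmp2 = cu_q_len_tmp + group_size // 2
        cu_q_len_tmp2[cu_q_len_tmp2 >= max_s] = torch.iinfo(cu_q_len_tmp2.dtype).min
        cu_q_len_tmp1p5 = cu_q_len_tmp + group_size
        cu_q_len_tmp1p5[cu_q_len_tmp1p5 >= max_s] = torch.iinfo(cu_q_len_tmp1p5.dtype).min
        cu_q_len_tmp = torch.cat([cu_q_len_tmp.unsqueeze(0),
                                  cu_q_len_tmp1p5.repeat(bsz - 1, 1),
                                  cu_q_len_tmp2.repeat(bsz, 1)], dim=0) + cu_q_lens[:-1].unsqueeze(-1)
        cu_q_len_tmp[cu_q_len_tmp > cu_q_lens[1:].unsqueeze(-1)] = torch.iinfo(cu_q_len_tmp1p5.dtype).min
        cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)
        return cu_q_lens[cu_q_lens >= 0]

    # Batch size 2 with the lengths from this run; should reproduce the tensor above.
    print(build_cu_q_lens([9243, 16196]))
    # Hypothetical batch size 3: each shifted half should start group_size // 2 after its sequence start.
    print(build_cu_q_lens([9243, 16196, 12000]))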

@yukang2017
Member

Hi,

Thanks. I think your modification is right. Could you please also check other batch sizes, such as 3 or 4?
