RuntimeError: FlashAttention only supports Ampere GPUs or newer. #307

Open
540627735 opened this issue Mar 21, 2024 · 1 comment

Comments

@540627735

Running the LoRA fine-tuning code on Kaggle raises this error:
Traceback (most recent call last):
File "/kaggle/working/llamatest/train/sft/finetune_clm_lora.py", line 692, in
main()
File "/kaggle/working/llamatest/train/sft/finetune_clm_lora.py", line 653, in main
train_result = trainer.train(resume_from_checkpoint=checkpoint)
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1624, in train
return inner_training_loop(
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1961, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2902, in training_step
loss = self.compute_loss(model, inputs)
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2925, in compute_loss
outputs = model(**inputs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1852, in forward
loss = self.module(*inputs, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/peft/peft_model.py", line 1083, in forward
return self.base_model(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 161, in forward
return self.model.forward(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1168, in forward
outputs = self.model(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 997, in forward
layer_outputs = self._gradient_checkpointing_func(
File "/opt/conda/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 451, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 230, in forward
outputs = run_function(*args)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 734, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 487, in forward
attn_output = self._flash_attention_forward(
File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 537, in _flash_attention_forward
attn_output_unpad = flash_attn_varlen_func(
File "/opt/conda/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 1059, in flash_attn_varlen_func
return FlashAttnVarlenFunc.apply(
File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 576, in forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
File "/opt/conda/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 85, in _flash_attn_varlen_forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: FlashAttention only supports Ampere GPUs or newer.
Exception raised from mha_varlen_fwd at /home/runner/work/flash-attention/flash-attention/csrc/flash_attn/flash_api.cpp:519 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6c (0x7f30afa8051c in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x84 (0x7f30afa35b04 in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: mha_varlen_fwd(at::Tensor&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor>&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor>&, c10::optional<at::Tensor>&, int, int, float, float, bool, bool, int, int, bool, c10::optional<at::Generator>) + 0x10d7 (0x7f2fff7a46d7 in /opt/conda/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so)
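
For context: the check that raises here lives in the FlashAttention 2 CUDA kernels, which require compute capability 8.0 (Ampere) or newer. Kaggle's free accelerators are T4s (Turing, SM 7.5) and P100s (Pascal, SM 6.0), so the kernel refuses to run there. A minimal sketch to confirm the GPU generation before training, assuming only PyTorch (already present in the Kaggle image):

import torch

# FlashAttention 2 needs compute capability >= 8.0 (Ampere or newer).
# Kaggle's free T4 is SM 7.5 (Turing); the P100 is SM 6.0 (Pascal).
major, minor = torch.cuda.get_device_capability(0)
print(f"GPU: {torch.cuda.get_device_name(0)}, compute capability {major}.{minor}")
if (major, minor) < (8, 0):
    print("FlashAttention 2 is unsupported on this GPU; disable flash attention.")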

@Wszl

Wszl commented Mar 30, 2024

Just set use_flash_attn to false in the launch script and it will work.
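
If the script did not expose a use_flash_attn flag, the same effect could be had at model-load time; a hedged sketch below, where attn_implementation is the standard transformers argument (available since v4.36, which the traceback's transformers version supports) and the model id is a placeholder:

from transformers import AutoModelForCausalLM

# Fall back to the eager attention implementation ("sdpa" also works on
# pre-Ampere GPUs) instead of flash_attention_2.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder model id
    attn_implementation="eager",
)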
