
Fail to run v2 with flash attention #140

Open
kenchanLOL opened this issue Mar 10, 2024 · 4 comments

@kenchanLOL

I got the following error message when I input a 2-minute-long video with the default hyperparameter settings (number of beams = 1, temperature = 1, video segments = 8) and "Hi" as the text input.

Input video shape: torch.Size([24, 224, 224])
n_position: 1568
pre_n_position: 784
Pretraining uses 4 frames, but current frame is 8
Interpolate the position embedding
Traceback (most recent call last):
  File "/home/ivc5/miniforge3/envs/videochat/lib/python3.10/site-packages/gradio/routes.py", line 408, in run_predict
    output = await app.get_blocks().process_api(
  File "/home/ivc5/miniforge3/envs/videochat/lib/python3.10/site-packages/gradio/blocks.py", line 1315, in process_api
    result = await self.call_function(
  File "/home/ivc5/miniforge3/envs/videochat/lib/python3.10/site-packages/gradio/blocks.py", line 1043, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/home/ivc5/miniforge3/envs/videochat/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "/home/ivc5/miniforge3/envs/videochat/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2134, in run_sync_in_worker_thread
    return await future
  File "/home/ivc5/miniforge3/envs/videochat/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 851, in run
    result = context.run(func, *args)
  File "demo.py", line 84, in gradio_answer
    llm_message,llm_message_token, chat_state = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=1000, num_beams=num_beams, temperature=temperature)
  File "/home/ivc5/Ask-Anything/video_chat2/conversation.py", line 64, in answer
    outputs = self.model.llama_model.generate(
  File "/home/ivc5/miniforge3/envs/videochat/lib/python3.10/site-packages/peft/peft_model.py", line 1140, in generate
    outputs = self.base_model.generate(*args, **kwargs)
  File "/home/ivc5/miniforge3/envs/videochat/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ivc5/miniforge3/envs/videochat/lib/python3.10/site-packages/transformers/generation/utils.py", line 1485, in generate
    return self.sample(
  File "/home/ivc5/miniforge3/envs/videochat/lib/python3.10/site-packages/transformers/generation/utils.py", line 2524, in sample
    outputs = self(
  File "/home/ivc5/miniforge3/envs/videochat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ivc5/miniforge3/envs/videochat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ivc5/Ask-Anything/video_chat2/models/blip2/modeling_llama_mem.py", line 674, in forward
    outputs = self.model(
  File "/home/ivc5/miniforge3/envs/videochat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ivc5/miniforge3/envs/videochat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ivc5/Ask-Anything/video_chat2/models/blip2/modeling_llama_mem.py", line 563, in forward
    layer_outputs = decoder_layer(
  File "/home/ivc5/miniforge3/envs/videochat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ivc5/miniforge3/envs/videochat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ivc5/Ask-Anything/video_chat2/models/blip2/modeling_llama_mem.py", line 292, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/ivc5/miniforge3/envs/videochat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ivc5/miniforge3/envs/videochat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ivc5/Ask-Anything/video_chat2/models/blip2/modeling_llama_mem.py", line 220, in forward
    qkv = torch.stack([query_states, key_states, value_states], dim=2)
RuntimeError: stack expects each tensor to be equal size, but got [1, 32, 1, 128] at entry 0 and [1, 32, 121, 128] at entry 1
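
For context, this looks like the usual incompatibility between a KV cache and a fused-QKV flash-attention path: after the first decoding step, query_states covers only the newest token (sequence length 1), while key_states and value_states include the cached history (sequence length 121), so the three tensors can no longer be stacked. A minimal sketch reproducing the mismatch, with shapes taken from the traceback above (illustrative only, not the repo's actual code):

import torch

bsz, n_heads, head_dim = 1, 32, 128
past_len = 120  # tokens already held in the KV cache

# With use_cache=True, every decoding step after the first computes a
# query for the single new token only...
query_states = torch.randn(bsz, n_heads, 1, head_dim)

# ...while keys and values are concatenated with the cached history.
key_states = torch.randn(bsz, n_heads, past_len + 1, head_dim)
value_states = torch.randn(bsz, n_heads, past_len + 1, head_dim)

# Packing Q, K, and V into a single tensor for flash attention assumes
# all three share the same sequence length, so this raises:
# RuntimeError: stack expects each tensor to be equal size,
# but got [1, 32, 1, 128] at entry 0 and [1, 32, 121, 128] at entry 1
qkv = torch.stack([query_states, key_states, value_states], dim=2)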
@Andy1621
Collaborator

Hi! Please try to set use_cache=False as here.
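
For anyone else hitting this, a minimal sketch of the two usual ways to apply that suggestion, using the attribute names visible in the traceback above (the exact line the link points to may differ):

# Option 1: disable the KV cache on the model config (a standard
# Hugging Face config attribute).
chat.model.llama_model.config.use_cache = False

# Option 2: pass the flag per call to generate().
outputs = chat.model.llama_model.generate(
    inputs_embeds=embs,
    max_new_tokens=1000,
    use_cache=False,
)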

@kenchanLOL
Author

kenchanLOL commented Mar 11, 2024

I added use_cache=False at

outputs = self.model.llama_model.generate(
    inputs_embeds=embs,
    max_new_tokens=max_new_tokens,
    stopping_criteria=self.stopping_criteria,
    num_beams=num_beams,
    do_sample=True,
    min_length=min_length,
    top_p=top_p,
    repetition_penalty=repetition_penalty,
    length_penalty=length_penalty,
    temperature=temperature,
    use_cache=False,
)

and I got a new error message.

Exception has occurred: RuntimeError
shape '[-1, 125]' is invalid for input of size 126
  File "/home/ivc5/Ask-Anything/video_chat2/models/blip2/modeling_llama_mem.py", line 515, in forward
    position_ids = position_ids.view(-1, seq_length).long()
  File "/home/ivc5/Ask-Anything/video_chat2/models/blip2/modeling_llama_mem.py", line 674, in forward
    outputs = self.model(
  File "/home/ivc5/Ask-Anything/video_chat2/conversation.py", line 64, in answer
    outputs = self.model.llama_model.generate(
  File "/home/ivc5/Ask-Anything/video_chat2/inference.py", line 68, in ask_questions
    llm_message, _, chat_state = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=1000, num_beams=num_beams, temperature=temperature)
  File "/home/ivc5/Ask-Anything/video_chat2/inference.py", line 82, in main
    results = ask_questions(chat, chat_state, img_list, questions)
  File "/home/ivc5/Ask-Anything/video_chat2/inference.py", line 91, in <module>
    main()
RuntimeError: shape '[-1, 125]' is invalid for input of size 126

This might be because there is a shape mismatch between inputs_embeds (shape = [1, 125]) and attention_mask (shape = [1, 126]); see the sketch below.
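
That reading matches the error: in Hugging Face's generation utilities, position_ids are typically derived from the attention mask, so a mask that is one token longer than inputs_embeds produces 126 position ids that cannot be reshaped into rows of seq_length = 125. A minimal sketch of the failure, with shapes taken from the error above (illustrative only):

import torch

seq_length = 125  # from inputs_embeds of shape [1, 125]

# The attention mask is one token longer than the embeddings.
attention_mask = torch.ones(1, 126, dtype=torch.long)

# generate() typically derives position ids from the mask like this.
position_ids = attention_mask.cumsum(-1) - 1

# modeling_llama_mem.py then reshapes to (-1, seq_length); 126 elements
# do not fit into rows of 125, hence:
# RuntimeError: shape '[-1, 125]' is invalid for input of size 126
position_ids = position_ids.view(-1, seq_length).long()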

@Andy1621
Collaborator

Can you simply try not to use flash_attn when inferring?
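
If it helps: flash attention in this repo is presumably toggled through the model config before the model is built. A hedged sketch of switching it off — both the config path and the key name use_flash_attn below are assumptions, so check the actual config file for the exact spelling:

import json

# Load the demo's model config and turn flash attention off before the
# model is constructed. The path and key name are assumptions, not
# verified against this repo.
with open("configs/config.json") as f:
    cfg = json.load(f)
cfg["model"]["use_flash_attn"] = False  # hypothetical key name
with open("configs/config.json", "w") as f:
    json.dump(cfg, f, indent=2)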

@kenchanLOL
Author

kenchanLOL commented Mar 12, 2024

Yes, I was able to run the model without flash_attn. However, I am trying flash attention because I want faster and more memory-efficient inference with long prompts. Apart from flash_attn, I also tried loading the model onto multiple GPUs, but that failed as well; I reported it in another open issue here. Thank you.
