
TypeError: LlamaRotaryEmbedding.forward() got an unexpected keyword argument 'seq_len' #5555

Open
alphanlp opened this issue Apr 4, 2024 · 4 comments
Labels
bug Something isn't working

Comments


alphanlp commented Apr 4, 2024

🐛 Describe the bug

File "/data/llmodel/miniconda3/envs/colossal/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/llmodel/miniconda3/envs/colossal/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/data/llmodel/huap/ColossalAI/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py", line 133, in attention_forward
cos, sin = self.rotary_emb(v, seq_len=kv_len)
File "/data/llmodel/miniconda3/envs/colossal/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/llmodel/miniconda3/envs/colossal/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/data/llmodel/miniconda3/envs/colossal/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
TypeError: LlamaRotaryEmbedding.forward() got an unexpected keyword argument 'seq_len'

Environment

python 3.10
transformers 4.39.2
colossalai 0.3.6

alphanlp added the bug label on Apr 4, 2024
@Orion-Zheng
Contributor

Hi 😄 This error occurs because, as of transformers v4.39, the seq_len argument was removed from LlamaRotaryEmbedding.forward(). The Colossal-LLaMA code, however, was written against a much older version (around v4.34, I guess). At that time, Flash Attention, a technique that significantly speeds up attention and reduces memory consumption, had just come out and had not yet been integrated into LlamaAttention. That is why we needed a flash_attn_patch to enable this feature back then. The patch is based on a function signature from an older version of transformers.
By now, Flash Attention has already been integrated into the Hugging Face Llama implementation (see the LlamaFlashAttention2 and LlamaSdpaAttention classes). So I think you can simply set use_flash_attn to False, and the Llama model will automatically use the flash attention feature. I believe this patch will be removed later.
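To make the mismatch concrete, here is a minimal, self-contained sketch. The two stub classes below are hypothetical stand-ins (not the real transformers code) that mimic the signature change; the patched attention still passes seq_len=..., which the new forward() no longer accepts:

```python
class OldRotaryEmbedding:
    """transformers <= 4.37 style: forward() accepts an optional seq_len."""
    def forward(self, x, seq_len=None):
        return ("cos", "sin")


class NewRotaryEmbedding:
    """transformers >= 4.39 style: seq_len is gone, position_ids is required."""
    def forward(self, x, position_ids):
        return ("cos", "sin")


old, new = OldRotaryEmbedding(), NewRotaryEmbedding()
old.forward("v", seq_len=128)  # works on old versions

try:
    # This is what flash_attention_patch.py still does on line 133:
    new.forward("v", seq_len=128)
except TypeError as err:
    # -> "... got an unexpected keyword argument 'seq_len'"
    print(err)
```

This is exactly the failure mode in the traceback above: the patch's call site and the library's signature drifted apart.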

@shawnricecake

When I change transformers to 4.38.0, it shows

  File "/home/user1/workspace/colossal-ai/ColossalAI/examples/language/llama2/attn.py", line 133, in attention_forward
    cos, sin = self.rotary_emb(v, seq_len=kv_len)
  File "/home/user1/anaconda3/envs/colossalai/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: LlamaRotaryEmbedding.forward() missing 1 required positional argument: 'position_ids'

So, which version of transformers should I use with flash attention?

@Orion-Zheng
Contributor

Transformers v4.37 is OK. But as I said, you can use v4.39 and still enjoy the speedup from flash_attn by setting use_flash_attn to False, because flash attention has been integrated into the transformers library and our patch is no longer needed.
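If you need to support both code paths, a small version check can decide whether the old patch is usable. This helper is my own hypothetical sketch (not part of ColossalAI), and the v4.38 cutoff is inferred from the errors in this thread: the seq_len call still works on v4.37, and breaks from v4.38 on (where position_ids became required):

```python
def patch_is_compatible(transformers_version: str) -> bool:
    """Return True if the old flash_attention_patch (which calls
    rotary_emb(v, seq_len=...)) should still work with the given
    transformers version; the signature changed starting with v4.38."""
    major, minor = (int(part) for part in transformers_version.split(".")[:2])
    return (major, minor) < (4, 38)


print(patch_is_compatible("4.37.2"))  # patch works
print(patch_is_compatible("4.39.2"))  # use transformers' built-in flash attention instead
```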

@shawnricecake

> Transformers v4.37 is OK. But as I said, you can use v4.39 and still enjoy the speedup from flash_attn by setting use_flash_attn to False, because flash attention has been integrated into the transformers library and our patch is no longer needed.

Hi, it looks like if I set use_flash_attn to False, the GPU memory usage increases.

and here is my env:

Package                   Version
------------------------- -----------
absl-py                   2.1.0
aiohttp                   3.9.3
aiosignal                 1.3.1
annotated-types           0.6.0
async-timeout             4.0.3
attrs                     23.2.0
bcrypt                    4.1.2
beautifulsoup4            4.12.3
cachetools                5.3.3
certifi                   2024.2.2
cffi                      1.16.0
cfgv                      3.4.0
charset-normalizer        3.3.2
click                     8.1.7
cmake                     3.29.0.1
colossalai                0.3.6
contexttimer              0.3.3
cryptography              42.0.5
datasets                  2.18.0
decorator                 5.1.1
Deprecated                1.2.14
dill                      0.3.8
distlib                   0.3.8
dropout-layer-norm        0.1
einops                    0.7.0
fabric                    3.2.2
filelock                  3.13.3
flash-attn                2.2.1
frozenlist                1.4.1
fsspec                    2024.2.0
fused-dense-lib           0.0.0
google                    3.0.0
google-auth               2.29.0
google-auth-oauthlib      1.0.0
grpcio                    1.62.1
huggingface-hub           0.22.2
identify                  2.5.35
idna                      3.6
invoke                    2.2.0
Jinja2                    3.1.3
jsonschema                4.21.1
jsonschema-specifications 2023.12.1
lit                       18.1.2
Markdown                  3.6
markdown-it-py            3.0.0
MarkupSafe                2.1.5
mdurl                     0.1.2
mpmath                    1.3.0
msgpack                   1.0.8
multidict                 6.0.5
multiprocess              0.70.16
networkx                  3.3
ninja                     1.11.1.1
nodeenv                   1.8.0
numpy                     1.26.4
nvidia-cublas-cu11        11.10.3.66
nvidia-cublas-cu12        12.1.3.1
nvidia-cuda-cupti-cu11    11.7.101
nvidia-cuda-cupti-cu12    12.1.105
nvidia-cuda-nvrtc-cu11    11.7.99
nvidia-cuda-nvrtc-cu12    12.1.105
nvidia-cuda-runtime-cu11  11.7.99
nvidia-cuda-runtime-cu12  12.1.105
nvidia-cudnn-cu11         8.5.0.96
nvidia-cudnn-cu12         8.9.2.26
nvidia-cufft-cu11         10.9.0.58
nvidia-cufft-cu12         11.0.2.54
nvidia-curand-cu11        10.2.10.91
nvidia-curand-cu12        10.3.2.106
nvidia-cusolver-cu11      11.4.0.1
nvidia-cusolver-cu12      11.4.5.107
nvidia-cusparse-cu11      11.7.4.91
nvidia-cusparse-cu12      12.1.0.106
nvidia-nccl-cu11          2.14.3
nvidia-nccl-cu12          2.19.3
nvidia-nvjitlink-cu12     12.4.127
nvidia-nvtx-cu11          11.7.91
nvidia-nvtx-cu12          12.1.105
oauthlib                  3.2.2
packaging                 24.0
pandas                    2.2.1
paramiko                  3.4.0
pip                       23.3.1
platformdirs              4.2.0
pre-commit                3.7.0
protobuf                  5.26.1
psutil                    5.9.8
pyarrow                   15.0.2
pyarrow-hotfix            0.6
pyasn1                    0.6.0
pyasn1_modules            0.4.0
pycparser                 2.22
pydantic                  2.6.4
pydantic_core             2.16.3
Pygments                  2.17.2
PyNaCl                    1.5.0
python-dateutil           2.9.0.post0
pytz                      2024.1
PyYAML                    6.0.1
ray                       2.10.0
referencing               0.34.0
regex                     2023.12.25
requests                  2.31.0
requests-oauthlib         2.0.0
rich                      13.7.1
rotary-emb                0.1
rpds-py                   0.18.0
rsa                       4.9
safetensors               0.4.2
sentencepiece             0.1.99
setuptools                68.2.2
six                       1.16.0
soupsieve                 2.5
sympy                     1.12
tensorboard               2.14.0
tensorboard-data-server   0.7.2
tokenizers                0.13.3
torch                     2.0.0
tqdm                      4.66.2
transformers              4.33.3
triton                    2.0.0
typing_extensions         4.11.0
tzdata                    2024.1
urllib3                   2.2.1
virtualenv                20.25.1
Werkzeug                  3.0.2
wheel                     0.41.2
wrapt                     1.16.0
xentropy-cuda-lib         0.1
xxhash                    3.4.1
yarl                      1.9.4
