
RuntimeError: "fused_dropout" not implemented for 'Byte' when running trl ppo finetuning #10854

Open
Jasonzzt opened this issue Apr 23, 2024 · 3 comments

Comments

@Jasonzzt
Contributor

Machine: MAX1100
ipex-llm: 2.1.0b20240421
bigdl-core-xe-21: 2.5.0b20240421
bigdl-core-xe-esimd-21: 2.5.0b20240421

Related PR
When trying to run trl PPO fine-tuning on MAX1100, I got the following error.

(ppo) (base) wangyishuo@7cc25526b7ac:~/ziteng$ python ppo.py --model_name "/mnt/disk1/Llama-2-7b-chat-hf" --dataset_name "HuggingFaceH4/helpful_instructions"
/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: ''If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/datasets/load.py:1461: FutureWarning: The repository for HuggingFaceH4/helpful_instructions contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/HuggingFaceH4/helpful_instructions
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
  warnings.warn(
2024-04-22 19:34:28,707 - root - INFO - intel_extension_for_pytorch auto imported
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.15s/it]
/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
  warnings.warn(
/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:397: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
  warnings.warn(
2024-04-22 19:34:31,311 - root - INFO - peft adapter initialised
2024-04-22 19:34:31,315 - ipex_llm.transformers.utils - INFO - Converting the current model to fp4 format......
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.00it/s]
Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at /mnt/disk1/Llama-2-7b-chat-hf and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
0it [00:00, ?it/s]You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
0it [00:02, ?it/s]
Traceback (most recent call last):
  File "/home/wangyishuo/ziteng/ppo.py", line 248, in <module>
    response_tensors = ppo_trainer.generate(
                       ^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/trl/trainer/ppo_trainer.py", line 469, in generate
    response = self._generate_batched(
               ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/trl/trainer/ppo_trainer.py", line 556, in _generate_batched
    generations = unwrapped_model.generate(**padded_inputs, **generation_kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/trl/models/modeling_value_head.py", line 204, in generate
    return self.pretrained_model.generate(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/peft/peft_model.py", line 1190, in generate
    outputs = self.base_model.generate(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/ipex_llm/transformers/lookup.py", line 86, in generate
    return original_generate(self,
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/ipex_llm/transformers/speculative.py", line 103, in generate
    return original_generate(self,
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/generation/utils.py", line 1520, in generate
    return self.sample(
           ^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/generation/utils.py", line 2617, in sample
    outputs = self(
              ^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1183, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1070, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 798, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 386, in forward
    query_states = self.q_proj(hidden_states)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/peft/tuners/lora/layer.py", line 509, in forward
    result = result + lora_B(lora_A(dropout(x))) * scaling
                                    ^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/dropout.py", line 58, in forward
    return F.dropout(input, self.p, self.training, self.inplace)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/functional.py", line 1266, in dropout
    return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
                                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: "fused_dropout" not implemented for 'Byte'
@Uxito-Ada
Contributor

@leonardozcm Please take a look: could it be that this is not supported by our kernel? Thanks.

@leonardozcm
Contributor

Hi, I don't think the problem is that `_VF.dropout` is unimplemented in our kernels. Rather, the error indicates that the input tensor is in an 8-bit (Byte) data format, which is not a supported dtype for `torch.nn.functional.dropout`.
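
For reference, a minimal standalone sketch of the dtype restriction (my own illustration, not taken from the original ppo.py script); the exact kernel name in the error message depends on the backend, with "fused_dropout" appearing on XPU/CUDA:

```python
# Sketch: torch dropout only supports floating-point inputs, so an 8-bit
# ("Byte"/uint8) tensor is rejected at the kernel level.
import torch
import torch.nn.functional as F

x = torch.randint(0, 256, (4, 8), dtype=torch.uint8)  # 8-bit "Byte" tensor

try:
    F.dropout(x, p=0.1, training=True)
except RuntimeError as e:
    # On XPU/CUDA this is the same '"fused_dropout" not implemented for "Byte"'
    # error as in the traceback above; CPU rejects it with a different kernel name.
    print(e)

# Casting to a floating dtype first avoids the error.
_ = F.dropout(x.float(), p=0.1, training=True)
```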

@Uxito-Ada
Contributor

@Jasonzzt From the log, it looks like the PPO script also applies PEFT LoRA.
Therefore, as with QLoRA, instead of calling from_pretrained with a LoRA config to build the PEFT model directly, we should first load the base model and then use the get_peft_model, prepare_model_for_kbit_training, etc. methods from qlora.py to create the PEFT model. A model built this way sits on top of layers with supported operators, as in the linked code; a rough sketch follows below.
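
A minimal sketch of that flow (the module paths, low-bit format, and LoRA parameters below are assumptions based on the ipex-llm QLoRA example, not taken from this PPO script, and may need adjusting):

```python
# Sketch only: load the low-bit base model first, then wrap it with PEFT.
import torch
from ipex_llm.transformers import AutoModelForCausalLM
from ipex_llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
from peft import LoraConfig

# 1. Load the plain base model in a low-bit format (no LoRA config here).
model = AutoModelForCausalLM.from_pretrained(
    "/mnt/disk1/Llama-2-7b-chat-hf",
    load_in_low_bit="nf4",          # low-bit format; adjust as needed
    torch_dtype=torch.bfloat16,
).to("xpu")

# 2. Build the PEFT model on top, so the LoRA layers wrap supported
#    low-bit linear operators instead of raw 8-bit tensors.
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
```

The TRL value-head model and PPOTrainer would then be created on top of this PEFT model rather than from a PEFT model loaded directly via from_pretrained.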
