Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIXED] Llama 3 Finetuned model is not generating EOS token. #416

Open
KillerShoaib opened this issue May 3, 2024 · 7 comments
Open

[FIXED] Llama 3 Finetuned model is not generating EOS token. #416

KillerShoaib opened this issue May 3, 2024 · 7 comments
Labels
fixed - pending confirmation Fixed, waiting for confirmation from poster

Comments

@KillerShoaib
Copy link

I've fine-tuned the llama 3 8 billion model. I followed the notebook and only changed the dataset. The dataset is similar to the alpaca dataset but for the Bangla language. I've trained the model for 1 epoch (36hrs) on a single T4 GPU. But, when I'm trying to generate a response it is not generating any eos token. It will go on till hitting the max_new_token length and stop.

Here is a sample of the code that is creating the dataset. (The same as the colab notebook. Just change the dataset name and system prompt)

code:

alpaca_prompt = """Below is an instruction in bangla that describes a task, paired with an input also in bangla that provides further context. Write a response in bangla that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
def formatting_prompts_func(examples):
    instructions = examples["instruction"]
    inputs       = examples["input"]
    outputs      = examples["output"]
    texts = []
    for instruction, input, output in zip(instructions, inputs, outputs):
        # Must add EOS_TOKEN, otherwise your generation will go on forever!
        text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
        texts.append(text)
    return { "text" : texts, }
pass

from datasets import load_dataset
dataset = load_dataset("iamshnoo/alpaca-cleaned-bengali", split = "train")
dataset = dataset.map(formatting_prompts_func, batched = True,)

One single example of the dataset['text'] looks like this:

'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nপদার্থের পরিবর্তনশীলতা এর বৈজ্ঞানিক সংজ্ঞা কি?\n\n### Input:\n\n\n### Response:\nবিপাক একটি জীবের মধ্যে ঘটে যাওয়া সমস্ত জৈব রাসায়নিক বিক্রিয়াকে বোঝায়, যার মধ্যে এমন প্রতিক্রিয়া রয়েছে যা শক্তি উত্পাদন করতে অণু ভাঙ্গতে পারে (ক্যাটাবলিজম) এবং নতুন অণু তৈরি করে (অ্যানাবলিজম) । এই প্রতিক্রিয়াগুলি এনজাইম দ্বারা সহজতর হয় এবং বৃদ্ধি, প্রজনন এবং পরিবেশের প্রতিক্রিয়া হিসাবে প্রয়োজনীয় প্রক্রিয়াগুলির মাধ্যমে জীবন বজায় রাখার জন্য প্রয়োজনীয়। বিপাক বিশেষত খাদ্যের ভাঙ্গন এবং এটি শক্তিতে রূপান্তরিত হতে পারে।<|end_of_text|>'

The EOS token has been added to the text in the end

Here is the generation code (same as the notebook):

# alpaca_prompt = Copied from above
alpaca_prompt = """Below is an instruction in bangla that describes a task, paired with an input also in bangla that provides further context. Write a response in bangla that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer(
[
    alpaca_prompt.format(
        "সুস্থ থাকার তিনটি উপায় বলুন", # instruction
        "", # input
        "", # output - leave this blank for generation!
    )
], return_tensors = "pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens = 2048, use_cache = True)
tokenizer.batch_decode(outputs)

Here is the response output :

['<|begin_of_text|>Below is an instruction in bangla that describes a task, paired with an input also in bangla that provides further context. Write a response in bangla that appropriately completes the request.\n\n### Instruction:\nসুস্থ থাকার তিনটি উপায় বলুন\n\n### Input:\n\n\n### Response:\n১. নিয়মিত ব্যায়াম করুন: নিয়মিত শারীরিক ক্রিয়াকলাপ করা আপনার শরীরের স্বাস্থ্য এবং সুস্থতা বজায় রাখতে সহায়তা করতে পারে। এটি হার্ট রোগ, ডায়াবেটিস এবং স্থূলতার মতো দীর্ঘস্থায়ী রোগের ঝুঁকি হ্রাস করতে পারে। ২. স্বাস্থ্যকর খাদ্য খানঃ একটি সুষম এবং পুষ্টিকর ডায়েট খাওয়া আপনার শরীরের স্বাস্থ্য এবং সুস্থতা বজায় রাখতে সহায়তা করতে পারে। ফল, সবজি, পূর্ণ শস্য, চর্বিযুক্ত প্রোটিন এবং স্বাস্থ্যকর ফ্যাট সহ একটি ভারসাম্যপূর্ণ ডায়েট খাওয়া আপনার শরীরকে সঠিকভাবে কাজ করতে সহায়তা করতে পারে। ৩. পর্যাপ্ত ঘুম পানঃ পর্যাপ্ত ঘুম পাওয়া আপনার শরীরের স্বাস্থ্য এবং সুস্থতা বজায় রাখতে গুরুত্বপূর্ণ। প্রতি রাতে কমপক্ষে 7-8 ঘন্টা ঘুম পাওয়া আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। ঘুমের অভাব আপনার ইমিউন সিস্টেমকে দুর্বল করতে পারে, রোগের ঝুঁকি বাড়িয়ে তুলতে পারে এবং আপনার মানসিক স্বাস্থ্যের উপর নেতিবাচক প্রভাব ফেলতে পারে। সুতরাং পর্যাপ্ত ঘুম পাওয়া আপনার সামগ্রিক স্বাস্থ্য এবং সুস্থতা বজায় রাখতে গুরুত্বপূর্ণ। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য। এটি আপনার শরী']

I asked the model in Bangla "Tell me 3 ways I can be healthy" and the model generated a coherent response. But after finishing the response it starts spamming "এটি আপনার শরীরের স্বাস্থ্যের জন্য অপরিহার্য" (eng-translation: It is necessary for your body). And it goes till it hits the max_new_token length. I've tried different questions, but the result is always the same. I couldn't find a single time where the model generated the eos token.

The EOS token has been added to the data['text']. So in theory, If I fine-tune the model then it should learn to predict the EOS token. I've a total 51k samples and finetuned the model for 1 epoch.

One thing I've noticed is that in the original colab notebook, when the model was trained for 60 iterations and used to generate a response none of the responses generated EOS token.

@zifken
Copy link

zifken commented May 3, 2024

Did you consider using the llama3 chat template instead of the default one (check this notebook) ?
Alternatively you could use tools like guidance which offers a lot of options to stop generation (for example regex or substrings). However, you will need to convert your model to llama.cpp to use with guidance. You loose unsloth's inference speed up but you can run on cpu.

@DDCY220
Copy link

DDCY220 commented May 4, 2024

I encountered the same problem. I added EOS to the training data, but during prediction, the output always continues to the maximum number of tokens.

@mxtsai
Copy link

mxtsai commented May 4, 2024

I'm facing the same problem here.

@KillerShoaib
Copy link
Author

KillerShoaib commented May 5, 2024

I've figured out the solution. Below is the code for those who just want the solution, not the details:

Solution Code:

# change the padding tokenizer value
tokenizer.add_special_tokens({"pad_token": "<|reserved_special_token_0|>"})
model.config.pad_token_id = tokenizer.pad_token_id # updating model config
tokenizer.padding_side = 'right' # padding to right (otherwise SFTTrainer shows warning)

Now, pass the model and tokenizer to SFTTrainer.

Details of the solution

  1. I've come across a similar problem in this issue. It was for llama 2 model. A user pointed out that the pad_token_id & eos_token_id is the same. Therefore when the model is fine-tuning the loss function ignores both pad_token and eos_token. Thus, the model is not learning to predict the eos_token.
  2. Then I've checked the pad_token_id and eos_token_id for the unsloth-llama3. I found both the pad_token_id and the eos_token_id are the same.
print(f"Pad Token id: {tokenizer.pad_token_id} and Pad Token: {tokenizer.pad_token}")
print(f"EOS Token id: {tokenizer.eos_token_id} and EOS Token: {tokenizer.eos_token}")
>>> Pad Token id: 128001 and Pad Token: <|end_of_text|>
>>> EOS Token id: 128001 and EOS Token: <|end_of_text|> 
  1. Now that I've known that these 2 are the same. Well, all I've to do is change the pad_token_id. I've found this stack overflow question where it shows how to change the pad_token_id for falcon model.
  2. To change the pad_token_id you can not add any random value. It'll throw CUDA error. (I'm not sure but I'm assuming the reason for that error is the mismatch between tokenizer vocab size and model vocab size.)
  3. I've looked into the unsloth llama 3 model's tokenizer.json file and there are total 251 reserved special tokens. The values look like this <|reserved_special_token_0|> to <|reserved_special_token_250|>. You can use any of the reserved special token value as the pad_token value. I've used the first one <|reserved_special_token_0|>
  4. The code to change the value is written above.
  5. Now let's verify the pad_token_id and eos_token_id values.
print(f"Pad Token id: {tokenizer.pad_token_id} and Pad Token: {tokenizer.pad_token}")
print(f"EOS Token id: {tokenizer.eos_token_id} and EOS Token: {tokenizer.eos_token}")
>>> Pad Token id: 128002 and Pad Token: <|reserved_special_token_0|>
>>> EOS Token id: 128001 and EOS Token: <|end_of_text|>
  1. After that, I've trained my fine-tuned llama-3 model for just an extra 30 iterations with the newly changed pad_token_id. This time I ask the model same question as before and the model was able to generate eos_token and stopped before hitting the max_new_tokens length. Below I've shown 2 pictures showcasing the model's response for the same and different eos_token and pad_token.

pic1withoutEOS
pic2withEos

I'm hopping UnslothAI is going to see this bug and solve it in their colab notebook. Lots of people are facing this issue .

@danielhanchen
Copy link
Contributor

@KillerShoaib WHOOPS you are entirely correct!!!! I immediately updated all pad_tokens Unsloth has to <|reserved_special_token_250|> Thanks for the keen eye!!

@Nazzaroth2
Copy link

OMG Thank you for the solution here, was driving me nuts why llama3 was getting more rambling the more i trained it.

@danielhanchen danielhanchen added the fixed - pending confirmation Fixed, waiting for confirmation from poster label May 5, 2024
@KillerShoaib KillerShoaib changed the title Llama 3 Finetuned model is not generating EOS token. [FIXED] Llama 3 Finetuned model is not generating EOS token. May 5, 2024
@tdolega
Copy link

tdolega commented May 15, 2024

I suggest using <|end_of_text|> for pad token and <|eot_id|> for eos token.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
fixed - pending confirmation Fixed, waiting for confirmation from poster
Projects
None yet
Development

No branches or pull requests

7 participants