Add support for loading checkpoints with newly added tokens. #272

Open
charlesCXK wants to merge 1 commit into main

Conversation

charlesCXK

No description provided.

@charlesCXK charlesCXK changed the title Add support for newly added tokens. Add support for loading checkpoints with newly added tokens. Mar 22, 2024
@danielhanchen
Contributor

Wait, would this load the lm_head and embed_tokens matrices correctly?

@danielhanchen
Contributor

Would it not cause them to be randomly initialized?

@charlesCXK
Author

Would it not cause them to be randomly initialized?

I have tested the code with the following setup:

1. First, I add new tokens to the tokenizer:
########################################
# Add special tokens to the tokenizer.
########################################
if True:
    old_vocab_size = tokenizer.vocab_size  # note: vocab_size excludes any previously added tokens
    print('old vocab size: ', old_vocab_size)
    tokenizer.add_tokens("<NEWTOKEN>", special_tokens=True)
    tokenizer.add_tokens("</NEWTOKEN>", special_tokens=True)

    # test case
    print(tokenizer.tokenize("This is an example with <NEWTOKEN> and </NEWTOKEN> token."))  

    # We resize the embeddings to avoid index errors.
    model.resize_token_embeddings(len(tokenizer))
    model.config.vocab_size = len(tokenizer)

    # Initialize the new token embeddings with the average of the existing embeddings
    num_new_tokens = len(tokenizer) - old_vocab_size
    print("num_new_tokens:", num_new_tokens)
    input_embeddings = model.get_input_embeddings().weight.data
    output_embeddings = model.get_output_embeddings().weight.data
    input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
        dim=0, keepdim=True)
    output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
        dim=0, keepdim=True)
    input_embeddings[-num_new_tokens:] = input_embeddings_avg
    output_embeddings[-num_new_tokens:] = output_embeddings_avg

    # Unfreeze lm_head and the input embeddings so the new rows can be trained
    model.lm_head.weight.requires_grad = True
    model.get_input_embeddings().weight.requires_grad = True
2. I trained the model on a dataset for several steps and saved the LoRA checkpoint:
import os
import shutil

save_path = "/home/xxx"
if os.path.exists(save_path):
    shutil.rmtree(save_path)
model.save_pretrained(save_path)
3. Then I use the saved checkpoint for inference:
print('Use saved model for inference.')
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = save_path, # YOUR MODEL YOU USED FOR TRAINING
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    new_token_num = 0,
)
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer(
[
    "Continue the fibonnaci sequence. 1, 1, 2, 3, 5, 8"
], return_tensors = "pt").to("cuda")

from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128)
4. The output is the same as with the original model.
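For completeness, a small sanity check along the lines below can confirm that the added tokens and their embedding rows survive the save/reload cycle. This is only a sketch: it reuses tokenizer, model, save_path, and the loading arguments from the steps above, names the reloaded objects model2 / tokenizer2, and assumes the embedding weights are directly comparable (4-bit loading may introduce small numerical differences).

# Reload the checkpoint saved in step 2.
model2, tokenizer2 = FastLanguageModel.from_pretrained(
    model_name = save_path,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

# The reloaded vocab should contain the new tokens ...
assert "<NEWTOKEN>" in tokenizer2.get_vocab()
# ... and the embedding matrix should match the resized vocab.
assert model2.get_input_embeddings().weight.shape[0] == len(tokenizer2)

# The new rows should not be re-initialized: compare them with the rows set before saving.
new_ids = tokenizer2.convert_tokens_to_ids(["<NEWTOKEN>", "</NEWTOKEN>"])
before = model.get_input_embeddings().weight.data[new_ids].float().cpu()
after  = model2.get_input_embeddings().weight.data[new_ids].float().cpu()
print("max abs diff of new embedding rows:", (before - after).abs().max().item())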

@chtmp223

chtmp223 commented Apr 4, 2024

Hi @charlesCXK, when using this code, I noticed that the loaded model doesn't include the new token that I added before fine-tuning. Do you have to add the new token again for inference? For example,

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = save_path, # YOUR MODEL YOU USED FOR TRAINING
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    new_token_num = 1,        # 1 new added token
)
if "<pad>" not in tokenizer.get_vocab():
    tokenizer.add_tokens(["<pad>"], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))  

# Inference code goes here
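One thing worth noting about this workaround (a sketch, not part of the PR): whether the learnt row survives depends on whether the saved checkpoint already carries the extra embedding row. If it does, resize_token_embeddings(len(tokenizer)) is a no-op and only the tokenizer entry needs re-adding; if it does not, the resize allocates a freshly initialized row, so the learnt input/output rows would have to be restored by hand, e.g. from tensors exported before saving (the file name below is hypothetical):

import torch

if "<pad>" not in tokenizer.get_vocab():
    tokenizer.add_tokens(["<pad>"], special_tokens=True)
    pad_id = tokenizer.convert_tokens_to_ids("<pad>")
    if model.get_input_embeddings().weight.shape[0] < len(tokenizer):
        # The checkpoint did not carry the extra row: resize, then restore the learnt weights.
        model.resize_token_embeddings(len(tokenizer))
        saved = torch.load("new_token_embeds.pt")  # hypothetical tensors exported before saving
        model.get_input_embeddings().weight.data[pad_id]  = saved["input"]
        model.get_output_embeddings().weight.data[pad_id] = saved["output"]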

@danielhanchen
Contributor

Whoopsies sorry on the horrible delay - I'll review this PR and test it out - so sorry!

@danielhanchen
Contributor

@charlesCXK @chtmp223 Extreme apologies on the delay - I think I might have fixed it. You need to call add_new_tokens before get_peft_model to update the vocab, resize the embeddings, and also save the learnt embeddings:

from unsloth import add_new_tokens
from unsloth import FastLanguageModel

add_new_tokens(model, tokenizer, ["new_token_1", "new_token_2"])
model = FastLanguageModel.get_peft_model(model, ...)
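For anyone landing here later, a fuller sketch of how this is intended to slot into the usual flow might look like the following. The base model name and LoRA hyperparameters are placeholders taken from the common unsloth examples, not part of this PR:

from unsloth import add_new_tokens
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/mistral-7b-bnb-4bit",  # placeholder base model
    max_seq_length = 2048,
    load_in_4bit = True,
)

# Add the new tokens BEFORE attaching the LoRA adapters, so the vocab is
# updated and the embedding matrices are resized first.
add_new_tokens(model, tokenizer, ["new_token_1", "new_token_2"])

model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
)

# ... training happens here ...

# The learnt (resized) embeddings are saved together with the LoRA checkpoint.
model.save_pretrained("lora_model")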
