Tasks

- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)

Reproduction
Hi, I am encountering an issue when training a new tokenizer based on the 'meta-llama/Meta-Llama-3-8B' tokenizer. In particular, the tokenizer's post_processor ids are not being reset correctly. You can reproduce the bug by running the code below.
```python
from transformers import AutoTokenizer
import json

# Download the llama 3 tokenizer
original_tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B')

# Create a new tokenizer like the old tokenizer and train it
new_tokenizer = original_tokenizer.train_new_from_iterator(iter(['hello', 'world']), 1000)

# set the pad token on both
original_tokenizer.pad_token_id = original_tokenizer.eos_token_id
new_tokenizer.pad_token_id = new_tokenizer.eos_token_id

# try tokenizing with both
text = ['hello world', 'how are you today?']
batch = original_tokenizer(text, return_tensors='pt', padding=True, truncation=True)
print("Original bos_token_id:", original_tokenizer.bos_token_id)
print("Original tokenizer input_ids:")
print(batch.input_ids)
print()

batch = new_tokenizer(text, return_tensors='pt', padding=True, truncation=True)
print("New bos_token_id:", new_tokenizer.bos_token_id)
print("New tokenizer input_ids:")
print(batch.input_ids)

# print out the new tokenizer's post-processing info to show that the bos token was not changed
print("New tokenizer post-processing info:",
      json.dumps(json.loads(new_tokenizer._tokenizer.to_str())['post_processor'], indent=2))
```
System Info

transformers version: 4.36.1

Who can help?

@ArthurZucker
Running this shows that the new tokenizer's output still uses the original bos token id of 128000; the expected output is that it is replaced by the new tokenizer's bos token id of 0.
I believe this is caused by the train_new_from_iterator function not handling the case where the post-processor is of type Sequence (i.e. it contains multiple post-processors), in this code from that method:
transformers/src/transformers/tokenization_utils_fast.py, lines 793 to 813 at commit c48787f
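For illustration, the fix essentially needs to recurse into a Sequence post-processor and remap the stale ids inside every nested TemplateProcessing entry, not just a top-level one. Below is a minimal sketch of that remapping over the serialized post_processor JSON; the helper name remap_special_token_ids and the toy structure are mine (not the actual transformers code), and the "Sequence"/"processors"/"special_tokens" layout is assumed to follow how tokenizers serializes these objects:

```python
import json

def remap_special_token_ids(post_processor, id_map):
    """Recursively rewrite special-token ids in a serialized post_processor.

    Handles both a single TemplateProcessing entry and a Sequence that
    nests several post-processors (the case this issue describes).
    """
    if post_processor is None:
        return post_processor
    if post_processor.get("type") == "Sequence":
        # Recurse into every nested post-processor.
        for child in post_processor.get("processors", []):
            remap_special_token_ids(child, id_map)
    elif post_processor.get("type") == "TemplateProcessing":
        # Each special token maps to a list of ids; rewrite the stale ones.
        for token in post_processor.get("special_tokens", {}).values():
            token["ids"] = [id_map.get(i, i) for i in token["ids"]]
    return post_processor

# A toy serialized post_processor shaped like Llama 3's (ids are examples):
serialized = {
    "type": "Sequence",
    "processors": [
        {"type": "ByteLevel", "trim_offsets": False},
        {
            "type": "TemplateProcessing",
            "special_tokens": {
                "<|begin_of_text|>": {"id": "<|begin_of_text|>", "ids": [128000]}
            },
        },
    ],
}

patched = remap_special_token_ids(serialized, {128000: 0})
print(json.dumps(patched["processors"][1]["special_tokens"], indent=2))
```

The same walk could also serve as a user-side workaround: load `new_tokenizer._tokenizer.to_str()` with json, patch the ids, and rebuild the backend tokenizer from the patched string.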
Thanks in advance for the help!
Expected behavior
The expected behavior is that train_new_from_iterator properly overwrites the original special token ids in the fast tokenizer's Sequence post-processor whenever the new tokenizer's special token ids differ.