
Extended vocab tokenizer merging text into a single string without spaces while decoding #1501

savanth14 opened this issue Apr 17, 2024 · 3 comments



savanth14 commented Apr 17, 2024

@ArthurZucker @younesbelkada @Narsil @n1t0 I tried to add new vocab to the existing Mistral tokenizer vocab using the `add_tokens()` method. Everything went fine until I used the extended-vocab tokenizer to decode the encoded text: in the decoded output the spaces are completely missing and all the decoded tokens are merged into a single string. Can you please help me resolve this issue? Here's the sample code:

```python
import sentencepiece as spm
import transformers

sp = spm.SentencePieceProcessor(model_file='mistral_tok.model')
tokenizer1 = transformers.AutoTokenizer.from_pretrained("mistralai/mistral-7b-v0.1")

vocab = [sp.id_to_piece(idx) for idx in range(sp.get_piece_size())]

new_tokens = set(vocab) - set(tokenizer1.vocab.keys())

tokenizer1.add_tokens(list(new_tokens))
# output: 14756

print("After adding new tokens, length of mistral tokenizer:", len(tokenizer1))
# output: 46756

tel_text = "నేను బాగున్నాను. మీరు ఏలా ఉన్నారు?"  # original text

mistral_encode_ids = tokenizer1.encode(tel_text)
mistral_decode_text = tokenizer1.decode(mistral_encode_ids, skip_special_tokens=True)

print(mistral_decode_text)
# output: నేనుబాగున్నాను.మీరుఏలాఉన్నారు? # decoded text with the spaces missing
```
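
To see where the spaces go, it may help to look at the raw tokens rather than the decoded string. A minimal diagnostic sketch, reusing `tokenizer1` and `mistral_encode_ids` from the snippet above:

```python
# Convert the encoded ids back to their token strings (not decoded text).
# Comparing these tokens against the original sentence shows which word
# boundaries were dropped during the encode/decode round trip.
tokens = tokenizer1.convert_ids_to_tokens(mistral_encode_ids)
print(tokens)
```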

To dig further into the problem, I re-initialised the Mistral tokenizer from its original checkpoint "mistralai/mistral-7b-v0.1". Then I added 3 manually defined random tokens using the same `add_tokens` method. This time, encoding and decoding with the extended-vocab tokenizer worked fine: the decoded text retained the spacing of the original random text. Here's the code for this experiment:

```python
from transformers import AutoTokenizer

mistral_tok = AutoTokenizer.from_pretrained("mistralai/mistral-7b-v0.1")

new_tokens = ["yoyoyo", "xoxoxo", "z0z0z0"]
mistral_tok.add_tokens(list(new_tokens))

print("After adding new tokens, length of mistral tokenizer:", len(mistral_tok))

random_text = "yoyoyo xoxoxo z0z0z0!"
random_text_2 = "This is my new yoyoyo style xoxoxo of z0z0z0 writing!"

mistral_encode_ids = mistral_tok.encode(random_text)
mistral_decode_text = mistral_tok.decode(mistral_encode_ids, skip_special_tokens=True)

mistral_encode_ids_2 = mistral_tok.encode(random_text_2)
mistral_decode_text_2 = mistral_tok.decode(mistral_encode_ids_2, skip_special_tokens=True)

print(mistral_decode_text)
# output: yoyoyo xoxoxo z0z0z0! # decoded text with spacing intact

print(mistral_decode_text_2)
# output: This is my new yoyoyo style xoxoxo of z0z0z0 writing! # decoded text with spacing intact
```

Where is the problem? Why does the extended-vocab tokenizer fail to decode properly when the new tokens come from another tokenizer's vocab, yet decode correctly when the tokens are added manually?
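
One observable difference between the two cases is that the pieces coming from the SentencePiece model typically carry a leading "▁" word-boundary marker, while the manually added strings ("yoyoyo", ...) do not. A small sketch to check this, reusing `tokenizer1` from the first snippet:

```python
# List the tokens that were added on top of the base vocab and count how many
# carry SentencePiece's "▁" word-boundary marker.
added = tokenizer1.get_added_vocab()  # dict: added token -> id
with_marker = [tok for tok in added if tok.startswith("▁")]
print(len(added), "added tokens,", len(with_marker), "of them start with '▁'")
```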

In addition, I used the `train_new_from_iterator` method to train a new tokenizer based on the Mistral tokenizer, then used the same approach as above to extend the vocab of the old tokenizer. When I used this extended-vocab tokenizer for decoding, I observed that some spaces are missing while some of the tokens are merged.

```python
from datasets import load_dataset
from transformers import AutoTokenizer

# pick the model type
model_type = "mistralai/mistral-7b-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_type)

# Original vocab size.
print(len(tokenizer))
# Note: the output ids are mostly low ids in the 100s (byte-fallback pieces),
# since the Telugu text is not covered by the original vocab.
print(tokenizer("నేను బాగున్నాను. మీరు ఏలా ఉన్నారు?"))

dataset = load_dataset("ai4bharat/sangraha", data_files=["verified/tel/data-0.parquet"], split="train")
telugu_train = iter(dataset[i]['text'] for i in range(150000))

# Train a new tokenizer using the Telugu iterator and the old tokenizer object.
new_tokenizer = tokenizer.train_new_from_iterator(telugu_train, vocab_size=8000)

new_tokens = set(new_tokenizer.vocab.keys()) - set(tokenizer.vocab.keys())
tokenizer.add_tokens(list(new_tokens))

tel_text = "నేను బాగున్నాను. మీరు ఏలా ఉన్నారు?"

mistral_encode_ids = tokenizer.encode(tel_text)
mistral_decode_text = tokenizer.decode(mistral_encode_ids, skip_special_tokens=True)

new_encode_ids = new_tokenizer.encode(tel_text)
new_decode_text = new_tokenizer.decode(new_encode_ids, skip_special_tokens=True)

print("Length of telugu text: ", len(tel_text))
print('---')
print("Extended vocab mistral: ", mistral_encode_ids)
print(len(mistral_encode_ids))
print('---')
print("Extended vocab mistral: ", mistral_decode_text)
print('---')
print("New tokenizer trained on mistral: ", new_encode_ids)
print(len(new_encode_ids))
print('---')
print("New tokenizer trained on mistral: ", new_decode_text)

# output: Extended vocab mistral:  నేను బాగున్నాను.మీరు ఏలాఉన్నారు? # extended-vocab tokenizer: some spaces missing
# output: New tokenizer trained on mistral:  నేను బాగున్నాను. మీరు ఏలా ఉన్నారు? # new tokenizer trained from the mistral tokenizer: decodes correctly
```

Can you please suggest how to fix this issue?


cheburakshu commented May 13, 2024

You have nicely demonstrated and answered your question here by training a new tokenizer on a new vocab. 👏

> Train a new tokenizer using the Telugu iterator and the old tokenizer object.

@ArthurZucker
Collaborator

Hey @savanth14, you should set `normalized=False` when you add the token.
I also recommend setting `legacy=False` to make sure you don't run into these issues:

```python
from transformers import AddedToken, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/mistral-7b-v0.1", legacy=False, from_slow=True)
tokenizer.add_tokens([AddedToken("<mytoken>", normalized=False)])
```

Using main with huggingface/transformers#28881 merged will also help.
The issue with the spaces is basically that if you normalize the token, the surrounding space gets eaten by the normalization and is not restored when decoding.
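
For the bulk-add case from the original post, a hedged sketch of how this fix could be applied (untested; it assumes the same `mistral_tok.model` SentencePiece file as in the first snippet):

```python
import sentencepiece as spm
from transformers import AddedToken, AutoTokenizer

sp = spm.SentencePieceProcessor(model_file="mistral_tok.model")
vocab = [sp.id_to_piece(idx) for idx in range(sp.get_piece_size())]

# legacy=False + from_slow=True rebuild the fast tokenizer without the legacy quirks.
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/mistral-7b-v0.1", legacy=False, from_slow=True
)

new_tokens = set(vocab) - set(tokenizer.get_vocab().keys())

# normalized=False keeps the added tokens out of the normalizer, so the
# surrounding spaces should survive an encode/decode round trip.
tokenizer.add_tokens([AddedToken(t, normalized=False) for t in sorted(new_tokens)])

tel_text = "నేను బాగున్నాను. మీరు ఏలా ఉన్నారు?"
print(tokenizer.decode(tokenizer.encode(tel_text), skip_special_tokens=True))
```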

@ArthurZucker
Collaborator

FYI @itazap
