Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix overwrite bug when adding symbol to dictionary #5329

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

lydianish
Copy link

@lydianish lydianish commented Sep 15, 2023

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
  • Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #3064.
Fixes #3705.
Fixes #1309.

TLDR; This PR fixes the bug that duplicates the symbols that were meant to be overwritten in the vocabulary file. See detailed explanation in this blog post.

Expected behavior:

A Dictionary object has an indices dict and two lists (symbols and counts). By default, when loading a vocabulary from a file, a Dictionary instance is first created by adding 4 special tokens (<s>, <pad>, </s> and <unk> in that order). Then, all the entries from the file are appended to the Dictionary. If the vocabulary file already has some of the special tokens, their file entry should contain #fairseq:overwrite, otherwise a "duplicate" error will be raised at runtime. Furthermore, during preprocessing, the saved dictionary should not contain any of the special symbols.

Current behavior:

The add_symbol function is responsible for adding the symbols to the Dictionary. It has an overwrite argument that is set to True when the corresponding line in the file has #fairseq:overwrite. Rather than testing if word in self.indices and overwrite, it is currently testing if word in self.indices and not overwrite, which makes it ignore the case where the symbol should actually be overwritten. Hence, the symbol is appended to the symbols list, and its index is changed in the indices dict. This results in duplicate symbols and incorrect indices. Generally, only the special symbols will be affected. However, because the number of special tokens is set during initialization, it remains correct.

For example, a dictionary with 50K tokens that already has <s>, <pad> , </s> and <unk> with the #fairseq:overwrite tag will end up having 50004 tokens when loaded. This will also propagate to the subsequent model which will have an embedding dimension of 50004 instead of 50K. Also, with fairseq-preprocess, the resulting dictionary will skip the first 4 special symbols but will still contain the duplicate ones.

Domino effects and backward compatibility:

By fixing this bug, dictionary files will be loaded properly. However, this fix might cause problems in pipelines that use existing architectures and pretrained models because of the mismatch in sentencepiece encoding and/or embedding dimension.

For the sake of backward compatibility, a #fairseq:duplicate flag is introduced to ensure that duplicates are kept in the dictionary just like the bug. When used with fairseq-preprocess, the produced dict.txt file will also write #fairseq:duplicate next to the same symbols.

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Yes, I did 🙃

This bug ignored the tokens that were meant to be overwritten and appends them to the end of the dictionary symbols.

For example, a dictionary with 50K tokens that already has `<s>`, `</s>`, `<pad>` and `<unk>` with the #fairseq:overwrite tag will end up having 50004 tokens when loaded.
Assert that overwrite works as expected (i.e. ignoring the duplicates)
@lydianish lydianish marked this pull request as ready for review September 21, 2023 16:20
@lydianish lydianish marked this pull request as draft September 21, 2023 17:03
For backward compatibility with the existing models/pipelines that uses a flawed dictionary loaded from file (before the bug fix)
@lydianish lydianish marked this pull request as ready for review September 21, 2023 20:33
@lydianish lydianish marked this pull request as draft September 21, 2023 21:40
…tionary

After fixing the behaviour of add_symbol, two of the unit tests were failing because they called the function with the default value of overwrite (False).
@lydianish lydianish marked this pull request as ready for review September 21, 2023 21:47
This ensures compatibility with all the calls to add_symbol across the repo (which overwrite by default, as in the original implementation). The only place where the value is explicitly changed is when loading the dictionary from file (which was the source of the bug). In a file you have to explicitly say whether the tokens should be overwritten or duplicated
@lydianish lydianish marked this pull request as draft March 8, 2024 12:48
@lydianish lydianish marked this pull request as ready for review March 8, 2024 12:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
2 participants