Add Falcon3 support and Fix issue #10875 (#10883)

Open · wants to merge 5 commits into master
Conversation

mokeddembillel (Contributor):
This PR adds Falcon3 support and fixes issue #10875, which was introduced by the previous PR #10864 (see #10864 for details).

Details of the fix for issue #10875:

The issue is that, when using meta-llama/Llama-3.1-8B-Instruct, the `<|begin_of_text|>` token is prepended to every special token when doing `token = tokenizer.decode(tokenizer.encode(token))`.

The screenshot below shows the tokens before and after `token = tokenizer.decode(tokenizer.encode(token))`:

[screenshot]

I'm fixing this by passing `add_special_tokens=False` to `tokenizer.encode()`. Here is the result after the fix:

[screenshot]
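For reference, here is a minimal round-trip outside the convert script (a sketch, not part of the PR; it assumes `transformers` is installed and you have access to the gated Llama 3.1 tokenizer):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

token = "<|eot_id|>"  # any special/added token
# By default, encode() prepends the BOS token <|begin_of_text|>, so the
# round-trip no longer returns the original token.
print(repr(tokenizer.decode(tokenizer.encode(token))))
# With add_special_tokens=False, the round-trip leaves the token unchanged.
print(repr(tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))))
```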

To be extra safe, we apply `token = tokenizer.decode(tokenizer.encode(token))` only if `len(token) == 1`, which still fixes this issue when `\n` is encoded as `Ċ`.

Generation before the fix:

Prompt: Once upon a time in a land far away,
there was a kingdom ruled by a wise and just king. The kingdom was known for its beauty and prosperity, and the people lived in peace and harmony.ĊĊOne day, a terrible drought struck the land, and the crops began to wither and die. The king, worried about the well-being of his people, called upon his wise council to find a solution. The council, after much deliberation, decided to send a group of brave knights to search for a magical spring that was said to have the power to bring rain to the kingdom.

Generation after the fix:

Prompt: Once upon a time in a land far away,
there was a kingdom ruled by a wise and just king. The kingdom was known for its beauty and prosperity, and the people lived in peace and harmony.

One day, a terrible drought struck the land, and the crops began to wither and die. The king, worried about the well-being of his people, called upon his wise council to find a solution. The council, after much deliberation, decided to send a group of brave knights to search for a magical spring that was said to have the power to bring rain to the kingdom.

@ggerganov @compilade @slaren

github-actions bot added the `python` (python script changes) label on Dec 18, 2024
```diff
@@ -525,6 +525,11 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]:
             else:
                 token: str = reverse_vocab[i]
                 if token in added_vocab:
+                    # We need to manually encode and decode the added tokens in case special characters
+                    # used for `\n` / `\t` have been manually added in the added tokens
+                    if len(token) == 1:
```
Contributor (review suggestion on the diff above):

```diff
-                    if len(token) == 1:
+                    # To avoid unexpected issues - we make sure to encode single-char tokens
+                    if len(token) == 1:
```

ggerganov (Owner):

I'm looking at the Falcon tokenizer and I don't see any added tokens that have \n or \t: https://huggingface.co/tiiuae/Falcon3-7B-Instruct/raw/main/tokenizer.json

For which tokens does this change make a difference?

Maybe also add some logs to know when this path is being triggered so we can spot any potential problems with other models.

younesbelkada (Contributor) · Dec 18, 2024:

Chiming in here! The added token is:

```json
{
  "id": 12,
  "content": "Ċ",
  "single_word": false,
  "lstrip": false,
  "rstrip": false,
  "normalized": false,
  "special": false
}
```

(`\t` is id 13.)
The only way to convert it properly to `\n` is to encode/decode it using the tokenizer.
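To see that round-trip in isolation, a sketch (not from the PR; it assumes the tiiuae/Falcon3-7B-Instruct tokenizer is reachable):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tiiuae/Falcon3-7B-Instruct")

# "Ċ" and "ĉ" are the GPT-2-style byte-level representations of "\n" and "\t";
# the tokenizer's byte-level decoder maps them back to the real control characters.
for ch in ("Ċ", "ĉ"):
    ids = tokenizer.encode(ch, add_special_tokens=False)
    print(ch, ids, repr(tokenizer.decode(ids)))
# Expected, per this thread: "Ċ" (id 12) -> '\n' and "ĉ" (id 13) -> '\t'
```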

mokeddembillel (Contributor, author):

I just added a log message inside the if statement.

```python
                    # used for `\n` / `\t` have been manually added in the added tokens
                    # To avoid unexpected issues - we make sure to encode single-char tokens
                    if len(token) == 1:
                        logger.info("Encode-Decode special characters using AutoTokenizer")
```
ggerganov (Owner):

I was thinking about comparing the token before and after the encoding, and printing the log only if there is a difference.

mokeddembillel (Contributor, author) · Dec 18, 2024:

That's a good idea. Done!

```
INFO:hf-to-gguf:'Ċ' is encoded and decoded back to '\n' using AutoTokenizer
INFO:hf-to-gguf:'ĉ' is encoded and decoded back to '\t' using AutoTokenizer
```
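The logic behind those log lines presumably looks roughly like this (a sketch with assumed variable names; the PR diff has the exact code):

```python
if len(token) == 1:
    previous_token = token
    # Round-trip through the tokenizer, without adding BOS/EOS special tokens
    token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
    if previous_token != token:
        logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
```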

ggerganov (Owner) left a comment:

Seems OK to me, but I am not sure about the full implications of this change for all other models. Want to wait for some feedback from the community.

The alternative is to find a way to apply this logic only inside the `FalconModel` class.

@ggerganov ggerganov requested a review from compilade December 18, 2024 08:41
mokeddembillel (Contributor, author) commented Dec 18, 2024:

Actually, there's no `FalconModel` class, and our model type is llama, so we can't use that as a check. The only solution I see is to wait for feedback from the community; if any error related to this shows up, I'll be happy to address and fix it quickly.

compilade (Collaborator) commented Dec 18, 2024:

> but I am not sure about the full implications of this change for all other models.

This can be tested by converting all the tokenizers fetched by convert_hf_to_gguf_update.py and comparing the hashes of the conversions before and after this change (which I can't do right now, but will when I can).
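One way to script that comparison (a sketch under assumptions: the tokenizer models were already fetched into models/tokenizers/ by convert_hf_to_gguf_update.py, and the output path is arbitrary):

```python
import hashlib
import subprocess
import sys
from pathlib import Path

# Convert each fetched tokenizer to a vocab-only GGUF and record its SHA-256;
# run once on master and once on this branch, then diff the printed hashes.
for model_dir in sorted(Path("models/tokenizers").iterdir()):
    out = Path("/tmp") / f"{model_dir.name}.gguf"
    subprocess.run(
        [sys.executable, "convert_hf_to_gguf.py", str(model_dir),
         "--outfile", str(out), "--vocab-only"],
        check=True,
    )
    print(model_dir.name, hashlib.sha256(out.read_bytes()).hexdigest())
```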

I think what would solve variations of this problem for other models in the future (for another PR) would be one of two things. Either normalize all added tokens which are marked "normalized": false, since the added tokens are internally assumed to be pre-normalized; this is the same problem which #8228 attempted to solve, but that fix wasn't general enough (it only normalizes "▁" to " ", which solved the problem for Gemma). Or handle non-normalized added tokens internally by adding a token attribute for them. The latter would depend on proper support for token attributes stored in GGUF files, which isn't yet complete: per-token attributes were added in #7685, but they aren't stored in GGUF models, and LLAMA_TOKEN_ATTR_NORMALIZED isn't really handled.

younesbelkada (Contributor):
Thanks @compilade!
There might be an easier solution: I am about to manually modify the normalized characters in the tokenizer file (the problem is only with \n and \t, which have been explicitly added as special tokens) and push the normalized tokenizer to all the repos. Then we can convert this PR to simply adding the falcon3 pre-tokenizer. What do you think?

compilade (Collaborator):
> There might be an easier solution, I am about to manually modify the normalized characters (since the problem is only for \n and \t that have been explicitly added as special tokens) in the tokenizer file and push the normalized tokenizer on all repos - then we can convert this PR to simply adding the falcon3 pre-tokenizer - what do you think?

@younesbelkada

That could also work, as long as it's done correctly. The added tokens are in both tokenizer.json and tokenizer_config.json. If you do this, make sure it doesn't have unintended consequences.

This is otherwise a nice edge case that I think the convert scripts should have handled correctly, so part of me wants to keep the tokenizers the same.

younesbelkada (Contributor):

Perfect, thanks, will test that out and update here!

younesbelkada (Contributor) commented Dec 18, 2024:

@compilade I just did some tests, and I think we can't go with the solution I suggested above, mainly for backward-compatibility reasons. Before the manual changes:

```
(Pdb) tok.encode("ĉ")
[13]
```

After the fix I suggested:

```
(Pdb) fixed_tokenizer.encode("ĉ")
[2150, 2237]
```

--> For the same token we now get different encodings. Since all Falcon3-series models have been trained with that tokenizer, even if the probability of this token appearing in a text is low, I am afraid it's far too risky a breaking change to introduce.
I also tried setting "normalized": true on these tokens and converting the model with this PR, and I'm still getting Ċ printed all over the place for line breaks.

Perhaps we can test that existing tokenizers are not affected by this PR, what do you think? Happy to help you on this as well.
