Llama-3 offset-mapping needs fixing #1553

Open
davidb-cerebras opened this issue Jun 14, 2024 · 5 comments


@davidb-cerebras

Opening a new issue to follow up on the previously opened issue here -- #1517

Here we can see the desired behavior: return_offsets_mapping from Mistral gives character indices corresponding to each token:

(Pdb) from transformers import AutoTokenizer
(Pdb) tok_mistral = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
(Pdb) tok_mistral(["Sample input"], return_offsets_mapping=True)
{'input_ids': [[1, 27797, 2787]], 'attention_mask': [[1, 1, 1]], 'offset_mapping': [[(0, 0), (0, 6), (6, 12)]]}
(Pdb) tok_mistral.convert_ids_to_tokens([1, 27797, 2787])
['<s>', '▁Sample', '▁input']
(Pdb) "Sample input"[0:6]
'Sample'
(Pdb) "Sample input"[6:12]
' input'

But for Llama-3 the offsets are not correct (following the pattern above, the expected mapping for this input would be [(0, 0), (0, 6), (6, 12)]):

(Pdb) tok_llama3 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") 
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
(Pdb) tok_llama3(["Sample input"], return_offsets_mapping=True)
{'input_ids': [[128000, 18031, 1988]], 'attention_mask': [[1, 1, 1]], 'offset_mapping': [[(0, 0), (0, 0), (6, 6)]]}

We can also see Llama-2 and GPT-2 behaving the same as Mistral, so Llama-3 is definitely the one exhibiting the unexpected behavior:

(Pdb) tok_llama2 = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
(Pdb) tok_llama2(["Sample input"], return_offsets_mapping=True)
{'input_ids': [[1, 21029, 1881]], 'attention_mask': [[1, 1, 1]], 'offset_mapping': [[(0, 0), (0, 6), (6, 12)]]}
(Pdb) tok_gpt2 = AutoTokenizer.from_pretrained("openai-community/gpt2") 
(Pdb) tok_gpt2(["Sample input"], return_offsets_mapping=True)
{'input_ids': [[36674, 5128]], 'attention_mask': [[1, 1]], 'offset_mapping': [[(0, 6), (6, 12)]]}
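
For completeness, a small reproduction script along these lines (assuming access to the same checkpoints used above) prints the decoded ids and offset spans side by side, making the Llama-3 mismatch visible at a glance:

from transformers import AutoTokenizer

text = "Sample input"
checkpoints = [
    "mistralai/Mistral-7B-Instruct-v0.1",
    "NousResearch/Llama-2-7b-hf",
    "openai-community/gpt2",
    "meta-llama/Meta-Llama-3-8B-Instruct",
]

for name in checkpoints:
    tok = AutoTokenizer.from_pretrained(name)
    enc = tok(text, return_offsets_mapping=True)
    print(name)
    for token_id, (start, end) in zip(enc["input_ids"], enc["offset_mapping"]):
        # The substring covered by each offset pair should line up with the token;
        # with Llama-3 the spans collapse to (0, 0) and (6, 6) and cover nothing.
        print(f"  id={token_id:>7}  offsets=({start}, {end})  text={text[start:end]!r}")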
@davidb-cerebras
Author

@ArthurZucker Is it possible to fix this in tokenizers?

@ArthurZucker
Collaborator

Yep, you are right, I'll dive in a bit to see why we have this!

@davidb-cerebras
Author

Awesome thank you!

@maximilianmordig

@ArthurZucker Is there a workaround in the meantime?

@ArthurZucker
Collaborator

Sorry, not yet! I am fixing a bunch of stuff, maybe #1568?
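
For anyone blocked on this in the meantime, one rough interim approximation (not from this thread, and not an official fix) is to rebuild the character offsets by decoding each token and locating it in the original text with a running cursor. The approximate_offsets helper below is hypothetical and assumes each token decodes to a contiguous substring of simple ASCII input:

from transformers import AutoTokenizer

def approximate_offsets(tokenizer, text):
    # Hypothetical helper: rebuild (start, end) character offsets per token.
    # Assumes each token decodes to a contiguous substring of `text`; this holds
    # for plain ASCII text but can break when bytes split multi-byte characters.
    ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    offsets = []
    cursor = 0
    for token_id in ids:
        piece = tokenizer.decode([token_id])
        start = text.find(piece, cursor) if piece else -1
        if start == -1:
            # Could not locate the decoded piece; emit an empty span as a fallback.
            offsets.append((cursor, cursor))
            continue
        end = start + len(piece)
        offsets.append((start, end))
        cursor = end
    return offsets

tok_llama3 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
print(approximate_offsets(tok_llama3, "Sample input"))
# Expected for this input: [(0, 6), (6, 12)]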
