'eos_token_id' for llama model.generate is not working #24644
Comments
Hey! A few things to note. Here is a working snippet:

```python
from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer

weights_dir = "huggyllama/llama-7b"
question = 'Hello, there!'

# if you want to add eos, set `add_eos_token=True`
tokenizer = LlamaTokenizer.from_pretrained(weights_dir, add_eos_token=True)
question_ids = tokenizer.encode(question, return_tensors='pt')
print(question_ids)
# tensor([[ 1, 15043, 29892, 727, 29991, 2]])
print(tokenizer.decode(question_ids[0], skip_special_tokens=True))
# 'Hello, there!'

# if you are not using the correct version of the tokenizer, special tokens are wrong
tokenizer = AutoTokenizer.from_pretrained(weights_dir, add_eos_token=True)
print(tokenizer.is_fast)
# True
question_ids = tokenizer.encode('Hello, there!</s>', return_tensors='pt')
print(question_ids)
# tensor([[ 1, 15043, 29892, 727, 29991, 829, 29879, 29958, 2]])
question_ids = tokenizer.encode('Hello, there! </s>', return_tensors='pt')
print(question_ids)
# tensor([[ 1, 15043, 29892, 727, 29991, 2, 2]])
```
@ArthurZucker Many thanks!
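For the generation side itself (not shown in the thread), here is a minimal sketch, assuming the same `huggyllama/llama-7b` checkpoint and default greedy decoding, of how the EOS id is typically passed so that `generate` stops at the first `</s>`:

```python
from transformers import LlamaTokenizer, AutoModelForCausalLM

weights_dir = "huggyllama/llama-7b"  # same checkpoint as in the snippet above
tokenizer = LlamaTokenizer.from_pretrained(weights_dir)
model = AutoModelForCausalLM.from_pretrained(weights_dir)

inputs = tokenizer("Hello, there!", return_tensors="pt")
# pass the EOS id explicitly so generation halts at the first </s> (id 2)
output_ids = model.generate(
    **inputs,
    max_new_tokens=50,
    eos_token_id=tokenizer.eos_token_id,
)
# skip_special_tokens strips <s>/</s> from the decoded text
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```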
System Info
transformers version: 4.30.2
Who can help?
No response
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
No matter how I change the parameters of `model.generate`, it always ignores `</s>` as the ending token (id: 2). In addition, the `skip_special_tokens` option of the tokenizer is not working either. What am I doing wrong? Please help, many thanks!
Expected behavior
`model.generate` should stop at the first occurrence of `</s>`.
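A small check of that expectation (a sketch; the checkpoint is an assumption carried over from the snippet above):

```python
from transformers import LlamaTokenizer, AutoModelForCausalLM

# assumed checkpoint, as in the maintainer's snippet
tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")

out = model.generate(**tokenizer("Hello, there!", return_tensors="pt"),
                     max_new_tokens=50)
# if generation stopped at </s>, the last id is the EOS id (2)
print(out[0][-1].item() == tokenizer.eos_token_id)
```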