[TextGeneration] max token refactor #1217
Conversation
The good: much less additional code and complexity than I thought.
The bad and ugly: could you add appropriate tests in tests/deepsparse/transformers/pipelines/test_text_generation.py?
Could max_tokens default to sequence_length - prompt_length so we don't risk running out of KV cache context? I'm not sure what happens there actually, especially when using the internal KV cache.
@dbogunowicz would like to get your opinion on this
@dsikka this is a very good idea.
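The default being discussed above can be sketched as a small helper. This is a hypothetical illustration, not the actual deepsparse implementation; the function name and the ValueError behavior are assumptions made for the example.

```python
# Hypothetical sketch: derive a safe default for max_tokens from the model's
# sequence_length and the tokenized prompt length, so that generation can
# never overrun the KV cache context.
def default_max_tokens(sequence_length: int, prompt_length: int) -> int:
    """Return the number of tokens that can still be generated."""
    remaining = sequence_length - prompt_length
    if remaining <= 0:
        # Assumed behavior for this sketch: refuse prompts that already
        # fill (or exceed) the model's context window.
        raise ValueError(
            f"prompt of length {prompt_length} does not fit in "
            f"sequence_length {sequence_length}"
        )
    return remaining


print(default_max_tokens(sequence_length=512, prompt_length=128))  # 384
```

With this default, a pipeline built for sequence_length=512 and given a 128-token prompt would generate at most 384 tokens unless the caller explicitly passes a smaller max_tokens.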
- Remove max_generated_tokens from the constructor and add it to the TextGenerationInput schema.
- Add num_generated_predictions to TextGenerationInput which, if > 1, repeats the input sequence and turns off deterministic prediction. If a sequence is already provided multiple times, it is not repeated.
Talking to the MLE team, I think for now we want to keep the defaults as-is and update them once we've established best practices.
For this ticket:
https://app.asana.com/0/1201735099598270/1205276886236972/f
Summary:
- Removes the max_tokens argument from the constructor and makes it part of the pipeline input.
- Adds num_generated_predictions to the input as well, which dictates the number of sequences generated for a given input. Similar to the Hugging Face implementation, we repeat the input based on the number provided, defaulting to 1. When num_generated_predictions is > 1, the engine's deterministic property is toggled to False.
- Supports str, List[str], and List[List[str]] inputs.

Testing:
- num_generated_predictions
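The input-repetition behavior described in the summary can be sketched roughly as follows. This is a hypothetical illustration of the described semantics, not the actual deepsparse code; the function name `repeat_inputs` and the flattening of `List[List[str]]` inputs are assumptions made for the example.

```python
# Hypothetical sketch of the num_generated_predictions behavior: a single
# string or flat list of strings is repeated num_generated_predictions
# times (mirroring the Hugging Face-style repetition described above),
# while sequences the caller already provided multiple times as an inner
# list are passed through unchanged.
from typing import List, Union


def repeat_inputs(
    sequences: Union[str, List[str], List[List[str]]],
    num_generated_predictions: int = 1,
) -> List[str]:
    if isinstance(sequences, str):
        sequences = [sequences]
    expanded: List[str] = []
    for item in sequences:
        if isinstance(item, list):
            # Already repeated explicitly by the caller; keep as given.
            expanded.extend(item)
        else:
            expanded.extend([item] * num_generated_predictions)
    return expanded


print(repeat_inputs("hello", 3))       # ['hello', 'hello', 'hello']
print(repeat_inputs([["a", "a"]], 3))  # ['a', 'a'] (not repeated again)
```

When num_generated_predictions is left at its default of 1, the input passes through unchanged, which matches the summary's note that repetition only kicks in for values greater than 1.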