
[TextGeneration] max token refactor #1217

Merged
dsikka merged 14 commits into main from update_max_token on Sep 12, 2023

Conversation

Contributor

@dsikka dsikka commented Aug 29, 2023

For this ticket:

https://app.asana.com/0/1201735099598270/1205276886236972/f

Summary:

  • Refactors the TextGeneration constructor to remove the max_tokens argument and instead makes it part of the pipeline input
  • Adds num_generated_predictions to the input as well, which dictates the number of sequences generated for a given input. Similar to the Hugging Face implementation, we repeat the input based on the number provided, defaulting to 1. When num_generated_predictions is > 1, the engine's deterministic property is toggled to False
  • In the case where the value is > 1, the output is a list of lists, where each inner list contains the generated sequences for a given prompt. This updates the sequences output to be one of str, List[str], or List[List[str]] (see the usage sketch below)
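
A rough illustration of the new call pattern; the argument names follow this PR description, while the model stub and the exact Pipeline.create invocation are assumptions for illustration, not code from this PR:

```python
# Minimal sketch of the new inference-time arguments; the model stub and exact
# field names are assumptions based on the PR description, not the merged code.
from deepsparse import Pipeline

pipeline = Pipeline.create(
    task="text_generation",
    model_path="zoo:hypothetical/text-generation-model",  # hypothetical stub
)

# max_tokens and num_generated_predictions are now passed at inference time
# instead of to the constructor.
output = pipeline(
    sequences="Write a haiku about sparsity",
    max_tokens=64,
    num_generated_predictions=3,
)

# With num_generated_predictions > 1 the output is expected to be a
# List[List[str]]: one inner list of candidate generations per prompt.
print(output.sequences)
```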

Testing

  • Tested locally using the new input arguments
  • Also added new tests covering num_generated_predictions (a sketch of such a test follows)
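
A hedged sketch of what such a test could look like; the test body and model stub below are illustrative assumptions, not the actual code added in tests/deepsparse/transformers/pipelines/test_text_generation.py:

```python
# Illustrative test sketch (not the PR's actual test): checks that the number
# of generations per prompt matches num_generated_predictions.
import pytest

from deepsparse import Pipeline


@pytest.mark.parametrize("num_preds", [1, 3])
def test_num_generated_predictions(num_preds):
    pipeline = Pipeline.create(
        task="text_generation",
        model_path="zoo:hypothetical/text-generation-model",  # hypothetical stub
    )
    output = pipeline(
        sequences="The capital of France is",
        max_tokens=8,
        num_generated_predictions=num_preds,
    )
    if num_preds == 1:
        # a single prompt with a single prediction stays a flat result
        assert isinstance(output.sequences, (str, list))
    else:
        # > 1 predictions per prompt: a list of candidate lists
        assert len(output.sequences[0]) == num_preds
```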

@dsikka dsikka changed the title from "Update max token" to "[TextGeneration] max token refactor" Aug 29, 2023
@dsikka dsikka marked this pull request as ready for review August 29, 2023 22:52
Contributor

@dbogunowicz dbogunowicz left a comment


The good: much less additional code and complexity than I thought.
The bad and ugly: could you add appropriate tests in tests/deepsparse/transformers/pipelines/test_text_generation.py?

bfineran previously approved these changes Sep 1, 2023
Member

@mgoin mgoin left a comment


Could max_tokens default to sequence_length - prompt_length so we don't risk running out of kv cache context? I'm not sure what actually happens there, especially when using the internal kv cache.

Contributor Author

dsikka commented Sep 7, 2023

Could max_tokens default to sequence_length - prompt_length so we don't risk running out of kv cache context? I'm not sure what actually happens there, especially when using the internal kv cache.

@dbogunowicz would like to get your opinion on this

@dbogunowicz
Contributor

@dsikka this is a very good idea.

- Remove max_generated_tokens from the constructor and add it to the TextGenerationInput schema
- Add num_generated_predictions to the TextGenerationInput, which, if > 1, repeats the input sequence and turns off deterministic prediction. If a sequence is already provided multiple times, it is not repeated again (see the schema sketch below).
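
A minimal sketch of how these fields could look on the input schema; the field names follow the PR discussion, while the defaults, types, and descriptions are assumptions for illustration, not the merged definitions:

```python
# Hedged sketch of the refactored input schema; everything beyond the field
# names (defaults, types, descriptions) is assumed for illustration.
from typing import List, Union

from pydantic import BaseModel, Field


class TextGenerationInput(BaseModel):
    sequences: Union[str, List[str]] = Field(
        description="The input prompt(s) to generate text from"
    )
    max_tokens: int = Field(
        default=1024,
        description="Maximum number of tokens to generate per prompt "
        "(moved here from the pipeline constructor)",
    )
    num_generated_predictions: int = Field(
        default=1,
        description="Number of sequences to generate per prompt; values > 1 "
        "turn off deterministic generation",
    )
```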
Contributor Author

dsikka commented Sep 12, 2023

Could max_tokens default to sequence_length - prompt_length so we don't risk running out of kv cache context? I'm not sure what actually happens there, especially when using the internal kv cache.

After talking to the MLE team, I think for now we want to keep the defaults as-is and update them once we've established best practices.
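
For reference, a minimal sketch of the default mgoin suggested (not adopted in this PR, since the defaults stay as-is for now):

```python
# Sketch of the suggested (not adopted) default: cap generation so that
# prompt tokens + generated tokens never exceed the compiled sequence length,
# i.e. never overflow the kv cache.
def default_max_tokens(prompt_length: int, sequence_length: int) -> int:
    return max(sequence_length - prompt_length, 0)
```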

@dsikka dsikka requested a review from bfineran September 12, 2023 16:39
Satrat previously approved these changes Sep 12, 2023
@dsikka dsikka merged commit a49ab47 into main Sep 12, 2023
@dsikka dsikka deleted the update_max_token branch September 12, 2023 20:47