[Fix][Text Generation Pipeline] Fix the erroneous sampling logic #1406

Merged 1 commit on Nov 15, 2023

Conversation

dbogunowicz
Contributor

@dbogunowicz dbogunowicz commented Nov 15, 2023

Fix Description

Before: top_k and top_p sampling were always applied, regardless of whether sampling=True or sampling=False.
Now: if sampling=False, we skip all sampling logic and go directly to the argmax function.
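
To illustrate the intended behavior, here is a minimal sketch of the early exit. The helper name `pick_token` and its signature are hypothetical and are not the pipeline's actual code; it also assumes 1-D logits.

```python
import numpy


def pick_token(logits: numpy.ndarray, sampling: bool, rng=None) -> int:
    """Hypothetical helper: greedy when sampling=False, otherwise draw from softmax."""
    if not sampling:
        # Early exit: no top_k, top_p, or temperature logic is touched.
        return int(numpy.argmax(logits))

    if rng is None:
        rng = numpy.random.default_rng()
    # Softmax over the logits (shifted by the max for numerical stability).
    shifted = numpy.asarray(logits, dtype=numpy.float64) - numpy.max(logits)
    probs = numpy.exp(shifted) / numpy.sum(numpy.exp(shifted))
    return int(rng.choice(len(probs), p=probs))
```

With sampling=False the result is deterministic: the index of the largest logit.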

@horheynm Could you please validate the rest of the logic in def generate(self, logits: numpy.ndarray)? In the most complex scenario, top_k, top_p, and sampling_temperature are all applied sequentially to the logits. Let's make sure the order in which these sampling functions are applied matches the one defined in HF (I assume that is the original implementation we want to mimic).
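
For reference while checking the order, below is a rough sketch of one possible sequence: temperature first, then top_k, then top_p. This order is an assumption about what HF does and should be verified against the HF source. The function `filter_logits` and its defaults are hypothetical, and it assumes 1-D logits.

```python
import numpy


def filter_logits(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Hypothetical sketch: temperature, then top_k, then top_p on 1-D logits."""
    # Dividing creates a new array, so the caller's logits are not modified.
    logits = numpy.asarray(logits, dtype=numpy.float64) / temperature

    if top_k > 0:
        # Mask everything below the k-th largest logit.
        kth = numpy.sort(logits)[-min(top_k, logits.size)]
        logits = numpy.where(logits < kth, -numpy.inf, logits)

    if top_p < 1.0:
        order = numpy.argsort(logits)[::-1]  # indices, most likely first
        probs = numpy.exp(logits[order] - logits[order][0])
        probs /= probs.sum()
        cumulative = numpy.cumsum(probs)
        # Drop a token once the mass before it already reaches top_p,
        # so the most likely token is always kept.
        remove = (cumulative - probs) >= top_p
        logits[order[remove]] = -numpy.inf

    return logits
```

Masked entries are set to -inf, so a softmax over the result assigns them zero probability. Changing the order of the three steps can change which tokens survive, which is why the order needs to match the reference implementation.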

@rahul-tuli
Member

Could you add the output from before and after the fix?

@dbogunowicz dbogunowicz merged commit fec7650 into main Nov 15, 2023
13 checks passed
@dbogunowicz dbogunowicz deleted the fix/damian/sampling branch November 15, 2023 14:45