Commit 976d032

fix PP : avoid using sampling metadata cache
Varun Sundar Rabindranath committed Aug 25, 2024
1 parent 6eac258 commit 976d032
Showing 2 changed files with 5 additions and 2 deletions.
4 changes: 3 additions & 1 deletion vllm/worker/model_runner.py
@@ -1395,7 +1395,9 @@ def prepare_model_input(
             sampling_metadata = SamplingMetadata.prepare(
                 seq_group_metadata_list, model_input.seq_lens,
                 model_input.query_lens, self.device, self.pin_memory,
-                generators, self.sampling_metadata_cache)
+                generators,
+                # TODO(varun) : Fix sampling metadata cache impl.
+                None)
         else:
             sampling_metadata = None
         is_prompt = (seq_group_metadata_list[0].is_prompt
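
The change passes None instead of self.sampling_metadata_cache as the final cache argument of SamplingMetadata.prepare, so each step builds fresh metadata. The commit title points at pipeline parallelism (PP): with several microbatches in flight at once, a reuse-based cache can hand two steps the same mutable object. Below is a minimal sketch of that suspected failure mode; ToyMetadata, ToyCache, and the toy prepare are hypothetical stand-ins, not vLLM's actual implementation.

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyMetadata:
    # Stands in for a mutable SamplingMetadata-like object.
    seq_ids: list = field(default_factory=list)


class ToyCache:
    # Reuses one metadata object per batch size to avoid reallocation.
    def __init__(self):
        self._pool = {}

    def get(self, batch_size: int) -> ToyMetadata:
        # Repeated calls with the same batch size return the SAME object.
        return self._pool.setdefault(batch_size, ToyMetadata())


def prepare(seq_ids: list, cache: Optional[ToyCache]) -> ToyMetadata:
    # Mirrors the cache parameter of SamplingMetadata.prepare: pass a cache
    # to reuse objects, or None to allocate fresh metadata on every call.
    meta = cache.get(len(seq_ids)) if cache is not None else ToyMetadata()
    meta.seq_ids = seq_ids  # in-place mutation, as a reuse cache implies
    return meta


cache = ToyCache()
# Under PP, two equally sized microbatches can be in flight at once.
step_a = prepare([0, 1], cache)
step_b = prepare([2, 3], cache)  # reuses and clobbers step_a's object
assert step_a is step_b and step_a.seq_ids == [2, 3]  # step_a is now stale

# Passing None, as this commit does, keeps each step's metadata private.
step_a = prepare([0, 1], None)
step_b = prepare([2, 3], None)
assert step_a is not step_b and step_a.seq_ids == [0, 1]
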
3 changes: 2 additions & 1 deletion vllm/worker/multi_step_model_runner.py
@@ -305,7 +305,8 @@ def prepare_model_input(
                 frozen_model_input.seq_lens[num_prompts:],
                 frozen_model_input.query_lens[num_prompts:],
                 self.device, self.pin_memory, generators,
-                self.sampling_metadata_cache)
+                # TODO (varun) : Fix sampling metadata cache impl
+                None)
             sampling_metadata_decodes.skip_sampler_cpu_output = (True)
 
         model_input = StatefulModelInput(
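
In the multi-step runner, sampling metadata is re-prepared only for the decode tail of the batch, which is why every argument in the hunk above is sliced from num_prompts onward. A toy illustration of that slicing follows; the values are invented for the example and assume prompts are ordered before decodes, as the slicing itself implies.

# Toy values; in vLLM these come from the scheduler, not hardcoded lists.
seq_group_metadata_list = ["prompt0", "decode0", "decode1"]
seq_lens = [17, 5, 9]    # total tokens per sequence so far
query_lens = [17, 1, 1]  # tokens processed this step (decodes process 1)
num_prompts = 1

# Same [num_prompts:] slicing as the hunk above: keep only decode entries.
decode_groups = seq_group_metadata_list[num_prompts:]
decode_seq_lens = seq_lens[num_prompts:]
decode_query_lens = query_lens[num_prompts:]

assert decode_groups == ["decode0", "decode1"]
assert decode_seq_lens == [5, 9] and decode_query_lens == [1, 1]
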
