Skip to content

Commit

Permalink
polish process_input section
Browse files Browse the repository at this point in the history
  • Loading branch information
HuanzhiMao committed Jul 21, 2024
1 parent 7d08daf commit 83912f0
Showing 1 changed file with 23 additions and 15 deletions.
38 changes: 23 additions & 15 deletions berkeley-function-call-leaderboard/model_handler/oss_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ def _batch_generate(
num_gpus=8,
):
from vllm import LLM, SamplingParams

print("start generating, test question length: ", len(test_question))

sampling_params = SamplingParams(
temperature=temperature,
max_tokens=max_tokens,
Expand All @@ -59,13 +59,13 @@ def _batch_generate(
tensor_parallel_size=num_gpus,
)
outputs = llm.generate(test_question, sampling_params)

final_ans_jsons = []
for output in outputs:
text = output.outputs[0].text
final_ans_jsons.append(text)
return final_ans_jsons

@staticmethod
def process_input(test_question, test_category, format_prompt_func):
prompts = []
Expand All @@ -79,19 +79,27 @@ def process_input(test_question, test_category, format_prompt_func):
return prompts

def inference(
self, test_question, test_category, num_gpus, format_prompt_func=_format_prompt
self,
test_question,
test_category,
num_gpus,
format_prompt_func=_format_prompt,
stop_token_ids=None,
max_model_len=None,
):
test_question = self.process_input(test_question, test_category, format_prompt_func)

test_question = self.process_input(
test_question, test_category, format_prompt_func
)

ans_jsons = self._batch_generate(
test_question,
test_category,
self.model_name,
self.temperature,
self.max_tokens,
self.top_p,
format_prompt_func,
num_gpus,
test_question=test_question,
model_path=self.model_name,
temperature=self.temperature,
max_tokens=self.max_tokens,
top_p=self.top_p,
stop_token_ids=stop_token_ids,
max_model_len=max_model_len,
num_gpus=num_gpus,
)

return ans_jsons, {"input_tokens": 0, "output_tokens": 0, "latency": 0}
Expand Down

0 comments on commit 83912f0

Please sign in to comment.