
Commit

simplify _batch_generate logic; separate out process_input section
HuanzhiMao committed Jul 20, 2024
1 parent c5ac395 commit 7d08daf
Showing 2 changed files with 32 additions and 31 deletions.
berkeley-function-call-leaderboard/model_handler/oss_handler.py (61 changes: 31 additions & 30 deletions)
@@ -31,57 +31,58 @@ def _format_prompt(prompt, function, test_category):

@staticmethod
def _batch_generate(
question_jsons,
test_category,
test_question,
model_path,
temperature,
max_tokens,
top_p,
format_prompt_func,
num_gpus,
stop_token_ids=None,
max_model_len=None,
num_gpus=8,
):
from vllm import LLM, SamplingParams

prompts = []
ans_jsons = []
for line in question_jsons:
ques_json = line
prompt = augment_prompt_by_languge(ques_json["question"], test_category)
functions = language_specific_pre_processing(
ques_json["function"], test_category, False
)
prompts.append(format_prompt_func(prompt, functions, test_category))
ans_id = shortuuid.uuid()
ans_jsons.append(
{
"answer_id": ans_id,
"question": ques_json["question"],
}
)

print("start generating: ", len(prompts))

print("start generating, test question length: ", len(test_question))

sampling_params = SamplingParams(
temperature=temperature, max_tokens=max_tokens, top_p=top_p
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
stop_token_ids=stop_token_ids,
)
llm = LLM(
model=model_path,
dtype="float16",
trust_remote_code=True,
tensor_parallel_size=num_gpus,
disable_custom_all_reduce=True,
max_model_len=max_model_len,
tensor_parallel_size=num_gpus,
)
outputs = llm.generate(prompts, sampling_params)
outputs = llm.generate(test_question, sampling_params)

final_ans_jsons = []
for output, ans_json in zip(outputs, ans_jsons):
for output in outputs:
text = output.outputs[0].text
ans_json["result"] = text
final_ans_jsons.append(ans_json)
final_ans_jsons.append(text)
return final_ans_jsons

@staticmethod
def process_input(test_question, test_category, format_prompt_func):
prompts = []
for ques_json in test_question:
prompt = augment_prompt_by_languge(ques_json["question"], test_category)
functions = language_specific_pre_processing(
ques_json["function"], test_category, False
)
prompts.append(format_prompt_func(prompt, functions, test_category))

return prompts

def inference(
self, test_question, test_category, num_gpus, format_prompt_func=_format_prompt
):

test_question = self.process_input(test_question, test_category, format_prompt_func)

ans_jsons = self._batch_generate(
test_question,
test_category,
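A minimal standalone sketch (not part of the commit) of the generation pattern the simplified _batch_generate now follows. The prompts, model path, and sampling values below are placeholders; in the actual handler the prompts come from process_input and the settings from the handler's configuration, and max_model_len / disable_custom_all_reduce are omitted here for brevity:

    from vllm import LLM, SamplingParams

    # Placeholder prompts; normally produced by
    # process_input(test_question, test_category, format_prompt_func).
    prompts = ["<formatted prompt 1>", "<formatted prompt 2>"]

    sampling_params = SamplingParams(
        temperature=0.7,      # placeholder
        max_tokens=512,       # placeholder
        top_p=1.0,            # placeholder
        stop_token_ids=None,  # model-specific stop tokens, if any
    )
    llm = LLM(
        model="path/to/model",   # placeholder model path
        dtype="float16",
        trust_remote_code=True,
        tensor_parallel_size=1,  # num_gpus
    )
    outputs = llm.generate(prompts, sampling_params)

    # The simplified _batch_generate returns the raw completion strings,
    # one per prompt, instead of wrapping them in answer dicts.
    results = [output.outputs[0].text for output in outputs]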
@@ -96,7 +96,7 @@ def load_file(test_categories):
num_gpus = args.num_gpus,
)
for index, res in enumerate(result):
result_to_write = {"id": index, "result": res["result"]}
result_to_write = {"id": index, "result": res}
handler.write(result_to_write, file_to_open)
else:
for index, test_case in enumerate(tqdm(test_cases)):
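On the runner side, because _batch_generate now returns plain strings, each completion is written directly rather than through res["result"]. A brief hedged sketch, where the example result list is hypothetical and handler.write / file_to_open are assumed from the surrounding (unexpanded) runner code:

    result = ["func_call_1(...)", "func_call_2(...)"]    # hypothetical outputs from _batch_generate
    for index, res in enumerate(result):
        result_to_write = {"id": index, "result": res}   # res is already the string itself
        handler.write(result_to_write, file_to_open)     # as in the updated runner loop above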
