refactor glm_handler to simplify logic and apply fix
HuanzhiMao committed Jul 21, 2024
1 parent 83912f0 commit 8034aed
Showing 1 changed file with 18 additions and 79 deletions.
berkeley-function-call-leaderboard/model_handler/glm_handler.py (97 changes: 18 additions & 79 deletions)
@@ -21,8 +21,10 @@
class GLMHandler(OSSHandler):
    def __init__(self, model_name, temperature=0.7, top_p=1, max_tokens=1000) -> None:
        super().__init__(model_name, temperature, top_p, max_tokens)
        self.tensor_parallel_size = 8
        self.max_model_len=4096
        self.stop_token_ids = [151329, 151336, 151338]


    def apply_chat_template(self, prompt, function, test_category):
        oai_tool = convert_to_tool(
            function, GORILLA_TO_OPENAPI, ModelStyle.OpenAI, test_category, True
@@ -32,94 +32,31 @@ def apply_chat_template(self, prompt, function, test_category):
            conversation, tokenize=False, add_generation_prompt=True
        )

    def _batch_generate(
        self,
        question_jsons,
        test_category,
        model_path,
        temperature,
        max_tokens,
        top_p,
        index,
        llm,
    ):
        from vllm import SamplingParams

        prompts = []
        ans_jsons = []
        for line in question_jsons:
            for key, value in FILENAME_INDEX_MAPPING.items():
                start, end = value
                if index >= start and index < end:
                    test_category = key
                    break
            prompts.append(line)
            ans_id = shortuuid.uuid()
            ans_jsons.append(
                {
                    "answer_id": ans_id,
                    "question": line,
                }
            )

        print("start generating: ", len(prompts))
        stop_token_ids = [151329, 151336, 151338]
        sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            stop_token_ids=stop_token_ids,
        )
        outputs = llm.generate(prompts, sampling_params)
        final_ans_jsons = []
        for output, ans_json in zip(outputs, ans_jsons):
            text = output.outputs[0].text
            ans_json["text"] = text
            final_ans_jsons.append(ans_json)
        return final_ans_jsons


    def inference(self, test_question, test_category, num_gpus):
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, trust_remote_code=True
        )

        chat_template_ques_jsons = []
        for line in test_question:
            prompt = augment_prompt_by_languge(line["question"], test_category)
            function = language_specific_pre_processing(
                line["function"], test_category, False
            )
            chat_template_ques_jsons.append(
                self.apply_chat_template(prompt, function, test_category)
            )

        chunk_size = len(test_question) // num_gpus
        from vllm import LLM

        llm = LLM(
            model=self.model_name,
            dtype="float16",
            trust_remote_code=True,
            tensor_parallel_size=self.tensor_parallel_size,
            max_model_len=4096,
        test_question = self.process_input(
            test_question, test_category, self.apply_chat_template
        )

        ans_jsons = self._batch_generate(
            test_question=test_question,
            model_path=self.model_name,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            top_p=self.top_p,
            stop_token_ids=self.stop_token_ids,
            max_model_len=self.max_model_len,
            num_gpus=num_gpus,
        )
        ans_jsons = []
        for i in range(0, len(test_question), chunk_size):
            output = self._batch_generate(
                chat_template_ques_jsons[i : i + chunk_size],
                test_category,
                self.model_name,
                self.temperature,
                self.max_tokens,
                self.top_p,
                i,
                llm,
            )
            ans_jsons.extend(output)

        return ans_jsons, {"input_tokens": 0, "output_tokens": 0, "latency": 0}


    def decode_ast(self, result, language="Python"):
        args = result.split("\n")
        if len(args) == 1:
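
For context, a minimal sketch of the flow this diff converges on: GLMHandler.inference now delegates prompt preparation to process_input and generation to the parent OSSHandler's _batch_generate, passing the GLM-specific stop_token_ids and self.max_model_len through the shared path instead of building its own vLLM engine and chunking by GPU. The OSSHandler bodies below are assumptions inferred from the keyword arguments at the call sites in this diff, not the repository's actual implementation.

# Hypothetical sketch only: the OSSHandler internals are inferred from the
# call sites in the diff above and are NOT the repository's real code.
from vllm import LLM, SamplingParams


class OSSHandlerSketch:
    def process_input(self, test_question, test_category, format_prompt_func):
        # Render each benchmark entry with the handler-specific chat template
        # (pre-processing such as augment_prompt_by_languge is omitted here).
        return [
            format_prompt_func(entry["question"], entry["function"], test_category)
            for entry in test_question
        ]

    def _batch_generate(self, test_question, model_path, temperature, max_tokens,
                        top_p, stop_token_ids, max_model_len, num_gpus):
        # A single shared vLLM engine sized by num_gpus replaces the deleted
        # per-chunk generation that GLMHandler previously did itself.
        llm = LLM(
            model=model_path,
            dtype="float16",
            trust_remote_code=True,
            tensor_parallel_size=num_gpus,
            max_model_len=max_model_len,
        )
        sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            stop_token_ids=stop_token_ids,
        )
        outputs = llm.generate(test_question, sampling_params)
        return [{"text": output.outputs[0].text} for output in outputs]

Read this way, the "fix" in the commit title appears to be that the GLM stop tokens and the 4096-token max_model_len now reach the generation call via handler attributes rather than being hard-coded inside the subclass.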