step2a.py
import os
import json
from tqdm import tqdm
import torch
from transformers import pipeline, AutoTokenizer
from dotenv import load_dotenv
from utils.argument import args
from utils.llm_utils import get_gpt_response, get_llama_response

if __name__ == "__main__":
    if args.llama:
        # Run Llama-2 locally through a Hugging Face text-generation pipeline.
        model = "meta-llama/Llama-2-70b-chat-hf"
        tokenizer = AutoTokenizer.from_pretrained(model, use_auth_token=True)
        pipe_line = pipeline("text-generation", model=model, torch_dtype=torch.float16, device_map="auto")
    else:
        # GPT path: pull API credentials and the model name from a .env file.
        # Without this, the get_gpt_response call below raises a NameError.
        load_dotenv()
        api_key = os.getenv("API_KEY")
        user = os.getenv("USER")
        model = os.getenv("MODEL")
        # Unused alternative: hosted Llama endpoints keyed by args.llama_ver.
        # url = ""
        # if args.llama_ver == "llama_70b":
        #     url = os.getenv("LLAMA_70b_URL")
        # elif args.llama_ver == "llama_13b":
        #     url = os.getenv("LLAMA_13b_URL")
        # elif args.llama_ver == "llama_7b":
        #     url = os.getenv("LLAMA_7b_URL")
    results = []

    # Read the step-2a system prompt.
    with open(args.step2a_prompt_path, "r") as label_file:
        system_prompt = label_file.read()

    # Read the step-1 results (one JSON object per line).
    with open(args.step1_result_path, "r") as answer_file:
        answers = answer_file.readlines()

    for line in tqdm(answers):
        # Parse each record once instead of re-parsing it on every access.
        answer = json.loads(line)
        user_prompt = answer["text"]

        ### Get LLM result ###
        if args.llama:
            response = get_llama_response(system_prompt, user_prompt, pipe_line, tokenizer)
        else:
            response = get_gpt_response(system_prompt, user_prompt, api_key, user, model)

        # Prefix each response with its source image file, when present.
        if "image_file" in answer:
            text = f"Image file-{answer['image_file']}; "
        else:
            text = ""
        results.append(text + response)
    ### Save results ###
    # Write one response per line to the step-2a output file.
    with open(args.step2a_result_path, "w") as file:
        file.write("\n".join(results))
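
For reference, judging from the fields this script reads, each line of the step-1 result file (args.step1_result_path) should be a JSON object with a required "text" field and an optional "image_file" field, e.g. (hypothetical values):

{"image_file": "0001.png", "text": "The initial answer produced in step 1 ..."}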
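
The two helpers imported from utils.llm_utils are defined elsewhere in the repo and are not shown in this file. As a rough illustration only, here is a minimal sketch of what get_llama_response might look like, assuming the standard Llama-2 chat template and default pipeline options; everything beyond the call signature used above is an assumption, not the repo's actual code:

# Hypothetical sketch, not the actual utils.llm_utils implementation.
def get_llama_response(system_prompt, user_prompt, pipe_line, tokenizer):
    # Llama-2 chat models expect an [INST] ... [/INST] prompt with an
    # optional <<SYS>> block carrying the system prompt.
    prompt = (
        f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
        f"{user_prompt} [/INST]"
    )
    outputs = pipe_line(
        prompt,
        max_new_tokens=512,      # assumed generation budget
        do_sample=False,         # assumed deterministic decoding
        return_full_text=False,  # keep only the newly generated text
    )
    return outputs[0]["generated_text"].strip()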