process_openai_few_shot_ner_properties.py
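"""Few-shot NER for physical quantities ("properties") using OpenAI chat models.

Reads paragraphs from a CSV/TSV file, collects candidate quantities from a
grobid-quantities server to use as few-shot hints, asks the selected LLM to
extract quantity mentions, and appends one CSV row per extracted entity.
Runs are resumable: records up to the last id found in the output file are
skipped.

Example invocation (paths are illustrative, not from the repository):

    python process_openai_few_shot_ner_properties.py \
        --input data/paragraphs.csv \
        --output results/ \
        --model gpt35_turbo
"""
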
import argparse
import csv
import os
from pathlib import Path

import dotenv

# Load environment variables before importing modules that may read them at import time.
dotenv.load_dotenv(override=True)

from document_qa.grobid_processors import GrobidQuantitiesProcessor
from grobid_quantities.quantities import QuantitiesAPI
from langchain.output_parsers import PydanticOutputParser
from langchain.schema import OutputParserException
from tqdm import tqdm

from commons.grobid.grobid_client_generic import GrobidClientGeneric
from commons.openai import CHATS
from commons.reader import get_last_id
from llm_mat_evaluation.ner.process_openai_ner_properties import prepare_data, extract_entities, \
    PROMPT_TEMPLATE_CHAT_USER_QUANTITIES, ListOfQuantitiesOutputParser, _parse_json
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Data preparation for the properties extraction using OpenAI LLMs")
    parser.add_argument("--input", help="Input CSV/TSV file", required=True)
    parser.add_argument("--output", help="Output file or directory; supports JSON, CSV, and TSV", required=True)
    parser.add_argument("--config", help="Configuration file", default="resources/config/config.yaml")
    parser.add_argument("--model", choices=CHATS.keys(), default="gpt35_turbo")
    args = parser.parse_args()

    input = args.input
    output = args.output
    config_file = args.config
    model = args.model

    llm = CHATS[model]
    if 'pl_tags' in llm:
        llm.pl_tags.append("evaluation")
        llm.pl_tags.append("ner")
        llm.pl_tags.append("quantities")

    config = GrobidClientGeneric().load_yaml_config_from_file(config_file)
    quantities_client = QuantitiesAPI(config['quantities']['server'], check_server=True)
    grobid_quantities_processor = GrobidQuantitiesProcessor(quantities_client)

    input_path = Path(input)
    if os.path.isdir(str(output)):
        output_path = Path(output) / "{}.{}.few-shot.output.csv".format(input_path.stem, model)
    else:
        output_path = Path(output)

    # Resume support: skip records whose id is already present in the output file.
    last_id_quantities = get_last_id(output_path)

    data_input = prepare_data(input)

    with open(output_path, encoding="utf-8", mode='a') as foq:
        fwq = csv.writer(foq, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        for idx, example in tqdm(enumerate(data_input), desc="record"):
            id = example['id']
            id_n = int(id)
            if last_id_quantities > 0 and id_n <= int(last_id_quantities):
                print("Skip quantity", id_n)
                continue  # already processed in a previous run

            filename = example['filename']
            paragraph_id = int(example['pid'])
            text = example['text']

            # Candidate quantities from grobid-quantities serve as few-shot hints for the LLM.
            hints = [entity['text'] for entity in grobid_quantities_processor.extract_quantities(text)]

            try:
                output_data_quantities = extract_entities(text, PROMPT_TEMPLATE_CHAT_USER_QUANTITIES, llm,
                                                          output_parser_class=ListOfQuantitiesOutputParser,
                                                          hints=hints)
            except OutputParserException:
                # Structured parsing failed: retry without the parser and recover the JSON explicitly.
                output_data_quantities_raw = extract_entities(text, PROMPT_TEMPLATE_CHAT_USER_QUANTITIES, llm,
                                                              hints=hints)
                if output_data_quantities_raw.startswith("I don't know") \
                        or output_data_quantities_raw.startswith("None"):
                    continue
                output_parser = PydanticOutputParser(pydantic_object=ListOfQuantitiesOutputParser)
                parsed_output = _parse_json(output_data_quantities_raw, llm, output_parser=output_parser)
                output_data_quantities = ListOfQuantitiesOutputParser.parse_to_list(parsed_output)

            fwq.writerows([[id, filename, paragraph_id, result] for result in output_data_quantities])