# run.py
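"""Runner for the distiller data-augmentation pipeline.

Composes prompts from the input data, generates augmentation candidates,
filters them, and writes both the raw generations and the accepted
augmentations to timestamped JSON files. Pass ``filter_all=true`` to
re-filter every cached generation instead of generating new data.
"""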
import logging
import os
import time

import hydra
import openai
import torch
import transformers
from hydra import initialize
from omegaconf import DictConfig, open_dict
from tqdm import tqdm

from distiller.db import get_database, get_all
from distiller.filter.core import FilterDataLoader
from distiller.generator import Generator
from distiller.prompt.retrieval import get_task_class
from distiller.utils import write_json

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger("distiller.runner")
transformers.logging.set_verbosity_error()


class DistillerRunner:
    def __init__(self, args, output_cache):
        self.args = args
        # Select a device: default to CUDA when available, otherwise CPU,
        # but prefer MPS (Apple Silicon) whenever it is present.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if torch.backends.mps.is_available():
            self.device = "mps"
        self.generator = Generator(args)
        self.output_cache = output_cache
        self.generation_outputs = []
        # Clear the Hydra context created by @hydra.main and re-initialize it
        # against ./templates so the task class can resolve prompt templates.
        hydra.core.global_hydra.GlobalHydra.instance().clear()
        initialize(config_path="./templates", version_base="1.3.2")
        self.task_container = get_task_class(self.args.task_name)

    def compose_loop(self):
        logger.info("Loading input data and composing prompts ...")
        queries = self.task_container.build_prompts(self.args)
        logger.info(f"Composed {len(queries)} prompts.")
        return queries

    def generate_loop(self, queries):
        logger.info("Generating augmentation data ...")
        generation_outputs = self.generator.batch_generate(queries)
        logger.info(f"Generated {len(generation_outputs)} data points.")
        logger.info("Postprocessing generation outputs ...")
        generation_outputs = self.task_container.postprocess_generation(
            self.output_cache, generation_outputs)
        logger.info("Postprocessing complete, updates persisted to cache.")
        return generation_outputs

    def _run_filters(self, outputs):
        """Run every task filter over ``outputs`` in batches of 16."""
        dataloader = FilterDataLoader(outputs, 16).dataloader
        filter_args = self.task_container.FILTER_ARGS
        filter_args["device"] = self.device
        for filter_class in self.task_container.FILTERS:
            task_filter = filter_class()
            for batch in tqdm(dataloader):
                task_filter.run(batch, self.output_cache, **filter_args)

    def filter_loop(self, generation_outputs):
        logger.info("Filtering generation outputs ...")
        self._run_filters(generation_outputs)
        logger.info("Filtering complete.")

    def filter_all_loop(self):
        logger.info("Filtering all previous generation outputs ...")
        all_outputs = get_all(self.output_cache)
        self._run_filters(all_outputs)
        logger.info("Filtering complete.")
        # Write only the accepted records, mirroring main_loop.
        accepted = self.report(all_outputs)
        self.write_augmentations(accepted)

    def report(self, outputs):
        """Log accept/reject counts and return the accepted records."""
        accepted = [record for record in outputs if record["accept"]]
        rejected = [record for record in outputs if not record["accept"]]
        logger.info(f"Report: {len(accepted)} accepted and {len(rejected)} rejected.")
        return accepted

    def write_outputs(self, generation_outputs):
        meta_data = {
            "dataset": self.args.dataset,
            "template_name": self.args.template_name,
            "gen_type": self.args.gen_type,
            "engine": self.args.model_name,
            "source_label": self.args.source_label,
            "target_label": self.args.target_label,
            "start": self.args.start,
            "end": self.args.end,
        }
        output_file = {
            "meta_data": meta_data,
            "outputs": generation_outputs,
        }
        timestr = time.strftime("%Y%m%d-%H%M%S")
        output_pth = os.path.join(self.args.output_dir, f"{timestr}.json")
        write_json(output_file, output_pth)
        logger.info(f"Outputs written to {output_pth}.")

    def write_augmentations(self, accepted):
        meta_data = {
            "dataset": self.args.dataset,
            "template_name": self.args.template_name,
        }
        augment_file = {
            "meta_data": meta_data,
            "outputs": accepted,
        }
        timestr = time.strftime("%Y%m%d-%H%M%S")
        augment_pth = os.path.join(self.args.aug_pth, f"{timestr}.json")
        write_json(augment_file, augment_pth)
        logger.info(f"Augmentations written to {augment_pth}.")

    def main_loop(self):
        """Full pipeline: compose prompts, generate, filter, then write
        both the raw outputs and the accepted augmentations."""
        queries = self.compose_loop()
        generation_outputs = self.generate_loop(queries)
        self.filter_loop(generation_outputs)
        self.write_outputs(generation_outputs)
        all_outputs = get_all(self.output_cache)
        accepted = self.report(all_outputs)
        self.write_augmentations(accepted)


def setup_path(args):
    category = f"{args.source_label}_{args.target_label}"
    category_file = category + ".jsonl"
    output_dir = os.path.join(args.data_dir, args.dataset, "output", category)
    input_pth = os.path.join(
        args.data_dir, args.dataset, "input", f"{args.source_label}.jsonl")
    demo_pth = os.path.join(args.data_dir, args.dataset, "examples", category_file)
    aug_dir = os.path.join(args.data_dir, args.dataset, "augment", category)
    os.makedirs(output_dir, exist_ok=True)
    assert os.path.isfile(input_pth), f"Missing input file: {input_pth}"
    assert os.path.isfile(demo_pth), f"Missing examples file: {demo_pth}"
    os.makedirs(aug_dir, exist_ok=True)
    # The Hydra config is struct-frozen; open_dict lets us attach derived paths.
    with open_dict(args):
        args.output_dir = output_dir
        args.input_pth = input_pth
        args.demo_pth = demo_pth
        args.aug_pth = aug_dir
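
# Expected layout under ``data_dir`` (inferred from the path joins above;
# names in angle brackets are placeholders):
#
#   <data_dir>/<dataset>/input/<source_label>.jsonl         input examples
#   <data_dir>/<dataset>/examples/<source>_<target>.jsonl   demonstrations
#   <data_dir>/<dataset>/output/<source>_<target>/          raw generations
#   <data_dir>/<dataset>/augment/<source>_<target>/         accepted augmentations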
@hydra.main(config_path="config/secret/", config_name="keys", version_base="1.3.2")
def set_up_api(keys: DictConfig):
openai.organization = keys.organization_token
openai.api_key = keys.api_token
@hydra.main(config_path="config", config_name="config", version_base="1.3.2")
def main(args: DictConfig):
output_cache = get_database(args.dataset, args.template_name)
runner = DistillerRunner(args, output_cache)
setup_path(args)
if args.filter_all:
runner.filter_all_loop()
else:
runner.main_loop()
if __name__ == "__main__":
set_up_api()
main()
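
# Example invocations (illustrative only; the real config keys and defaults
# live in config/config.yaml and config/secret/keys.yaml, not shown here):
#
#   python run.py dataset=<dataset> source_label=<src> target_label=<tgt>
#   python run.py filter_all=true   # re-filter everything in the cache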