-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathmain.py
113 lines (96 loc) · 3.35 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from base import Agent
from execution_pipeline import main
class ClassificationAgent(Agent):
    """Agent that assigns a piece of text to exactly one label from a label set.

    This is a template: every method below is a placeholder that raises
    ``NotImplementedError`` until it is replaced with a concrete
    (LLM-backed) implementation.
    """

    def __init__(self, config: dict) -> None:
        """Set up the underlying LLM.

        Args:
            config (dict): Agent configuration; which keys it carries is left
                to the implementer.
        """
        # TODO: build the LLM from ``config``.
        raise NotImplementedError()

    def __call__(self, label2desc: dict[str, str], text: str) -> str:
        """Choose the label from ``label2desc`` that best fits ``text``.

        Args:
            label2desc (dict[str, str]): Mapping from each candidate label to a
                natural-language description of that label.
            text (str): The input text to classify.

        Returns:
            str: The chosen label. It should be one of the keys of
            ``label2desc``.

        Example:
            With

                label2desc = {
                    "apple": "A fruit that is typically red, green, or yellow.",
                    "banana": "A long curved fruit that grows in clusters and has soft pulpy flesh and yellow skin when ripe.",
                    "cherry": "A small, round stone fruit that is typically bright or dark red.",
                }
                text = "The fruit is red and about the size of a tennis ball."

            the expected answer is ``"apple"`` -- a key of ``label2desc``,
            i.e. one of ``["apple", "banana", "cherry"]``.
        """
        # TODO: query the LLM and map its answer onto a key of label2desc.
        raise NotImplementedError()

    def update(self, correctness: bool) -> bool:
        """Adapt the agent using feedback on its prediction at the current step.

        Args:
            correctness (bool): Whether the prediction made at the current
                time step was correct.

        Returns:
            bool: Whether the prediction is correct.
            NOTE(review): presumably this just echoes ``correctness``;
            confirm against the execution pipeline.
        """
        # TODO: use the feedback (e.g. to adjust future prompts).
        raise NotImplementedError()
class SQLGenerationAgent(Agent):
    """Agent that writes SQL for a user query against a given table schema.

    This is a template: every method below is a placeholder that raises
    ``NotImplementedError`` until it is replaced with a concrete
    (LLM-backed) implementation.
    """

    def __init__(self, config: dict) -> None:
        """Set up the underlying LLM.

        Args:
            config (dict): Agent configuration; which keys it carries is left
                to the implementer.
        """
        # TODO: build the LLM from ``config``.
        raise NotImplementedError()

    def __call__(self, table_schema: str, user_query: str) -> str:
        """Produce SQL code that answers ``user_query`` over ``table_schema``.

        Args:
            table_schema (str): The table schema, as a string.
            user_query (str): The user's question to translate into SQL.

        Returns:
            str: The SQL code generated by the LLM.
        """
        # TODO: the returned string must be valid SQL code only -- nothing else.
        raise NotImplementedError()

    def update(self, correctness: bool) -> bool:
        """Adapt the agent using feedback on its SQL at the current time step.

        Args:
            correctness (bool): Whether the SQL code produced at the current
                time step was correct.

        Returns:
            bool: Whether the SQL code is correct.
            NOTE(review): presumably this just echoes ``correctness``, as in
            the classification agent's contract; confirm against the
            execution pipeline.
        """
        # TODO: use the feedback (e.g. to adjust future prompts).
        raise NotImplementedError()
if __name__ == "__main__":
    # Script entry point: pick the agent class from the benchmark name,
    # build it, and hand it to the execution pipeline.
    from argparse import ArgumentParser
    from execution_pipeline import main

    cli_parser = ArgumentParser()
    cli_parser.add_argument('--bench_name', type=str, required=True)
    args = cli_parser.parse_args()

    # Ordered (prefix, agent class) pairs; the first prefix that matches wins,
    # mirroring an if/elif chain.
    prefix_to_agent = (
        ("classification", ClassificationAgent),
        ("sql_generation", SQLGenerationAgent),
    )
    agent_cls = next(
        (cls for prefix, cls in prefix_to_agent
         if args.bench_name.startswith(prefix)),
        None,
    )
    if agent_cls is None:
        raise ValueError(f"Invalid benchmark name: {args.bench_name}")

    bench_cfg = dict(bench_name=args.bench_name)
    agent_config = {
        # TODO: specify your configs for the agent here
    }
    main(agent_cls(agent_config), bench_cfg)