-
Notifications
You must be signed in to change notification settings - Fork 5.5k
/
Copy pathutils.py
130 lines (104 loc) · 3.89 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import re
from datetime import datetime
from pathlib import Path
import nbformat
import yaml
from loguru import logger as _logger
from nbclient import NotebookClient
from nbformat.notebooknode import NotebookNode
from metagpt.roles.role import Role
def load_data_config(file_path="data.yaml"):
    """Read a YAML configuration file and return its parsed content.

    Args:
        file_path: Path of the YAML file to read. A relative path is resolved
            against the current working directory.

    Returns:
        The object produced by ``yaml.safe_load`` for the file content.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, so non-ASCII values load identically on every machine.
    with open(file_path, "r", encoding="utf-8") as stream:
        return yaml.safe_load(stream)
# Both YAML files are read at import time. The paths are relative, so they are
# resolved against the current working directory of the importing process.
DATASET_CONFIG = load_data_config("datasets.yaml")
DATA_CONFIG = load_data_config()
# Expose the dataset table through DATA_CONFIG under the "datasets" key.
DATA_CONFIG["datasets"] = DATASET_CONFIG["datasets"]
def get_mcts_logger(logfile_level="DEBUG", name=None):
    """Configure the shared loguru logger for MCTS runs and return it.

    Registers a custom ``"MCTS"`` level (severity 25, green) and adds a file
    sink at ``<work_dir>/<role_dir>/<log_name>.txt``, where both directory
    names are read from ``DATA_CONFIG``. Handlers already attached to the
    logger are left in place.

    Args:
        logfile_level: Minimum level written to the log file.
        name: Optional prefix for the log file name. When given, the file is
            named ``<name>_<YYYYMMDD>.txt``; otherwise ``<YYYYMMDD>.txt``.

    Returns:
        The loguru logger object imported by this module as ``_logger``.
    """
    date_stamp = datetime.now().strftime("%Y%m%d")
    log_name = f"{name}_{date_stamp}" if name else date_stamp

    _logger.level("MCTS", color="<green>", no=25)
    log_file = Path(DATA_CONFIG["work_dir"]) / DATA_CONFIG["role_dir"] / f"{log_name}.txt"
    _logger.add(log_file, level=logfile_level)
    # NOTE(review): `propagate` is a stdlib-logging concept; presumably this
    # assignment has no effect on a loguru logger -- confirm before removing.
    _logger.propagate = False
    return _logger
# Module-level logger, configured once when this module is imported.
mcts_logger = get_mcts_logger()
def get_exp_pool_path(task_name, data_config, pool_name="analysis_pool"):
    """Locate the experience-pool JSON file of a dataset.

    Args:
        task_name: Key of the dataset inside ``data_config["datasets"]``.
        data_config: Mapping that provides ``"datasets_dir"`` and
            ``"datasets"``; each dataset entry must have a ``"dataset"`` key
            naming its sub-directory.
        pool_name: Base name of the pool file, without the ``.json`` suffix.

    Returns:
        ``<datasets_dir>/<dataset>/<pool_name>.json`` when that path exists,
        otherwise ``None``.

    Raises:
        ValueError: If ``task_name`` is not a configured dataset.
    """
    root_dir = data_config["datasets_dir"]
    known_datasets = data_config["datasets"]
    # Guard clause: reject unknown datasets before building any path.
    if task_name not in known_datasets:
        raise ValueError(
            f"Dataset {task_name} not found in config file. Available datasets: {data_config['datasets'].keys()}"
        )
    candidate = os.path.join(root_dir, known_datasets[task_name]["dataset"], f"{pool_name}.json")
    return candidate if os.path.exists(candidate) else None
def change_plan(role, plan):
    """Attach a new plan to the first task that has no code yet.

    Args:
        role: Object exposing ``role.planner.plan.tasks``, an iterable of
            tasks that each have a ``code`` attribute.
        plan: Value stored on the ``plan`` attribute of the first task whose
            ``code`` is empty.

    Returns:
        ``True`` when every task already has code (nothing is modified),
        ``False`` when a pending task was found and updated.
    """
    print(f"Change next plan to: {plan}")
    for pending in role.planner.plan.tasks:
        if not pending.code:
            # Only the first code-less task is updated.
            pending.plan = plan
            return False
    return True
def is_cell_to_delete(cell: NotebookNode) -> bool:
    """Tell whether a notebook cell recorded a traceback in its outputs.

    Args:
        cell: Mapping-like notebook cell.

    Returns:
        ``True`` if any non-empty entry of ``cell["outputs"]`` contains a
        ``"traceback"`` key; ``False`` otherwise, including when the cell has
        no ``"outputs"`` key at all.
    """
    if "outputs" not in cell:
        return False
    return any(entry and "traceback" in entry for entry in cell["outputs"])
def process_cells(nb: NotebookNode) -> NotebookNode:
    """Keep only the clean code cells of a notebook and renumber them.

    Cells whose ``cell_type`` is not ``"code"`` and cells flagged by
    ``is_cell_to_delete`` are dropped. The remaining cells receive consecutive
    ``execution_count`` values starting at 1.

    Args:
        nb: Notebook whose ``"cells"`` entry is filtered; it is modified in
            place.

    Returns:
        The same ``nb`` object, with ``nb["cells"]`` replaced by the filtered
        list.
    """
    kept = []
    for cell in nb["cells"]:
        if cell["cell_type"] != "code" or is_cell_to_delete(cell):
            continue
        # The next execution count is one more than the cells kept so far.
        cell["execution_count"] = len(kept) + 1
        kept.append(cell)
    nb["cells"] = kept
    return nb
def save_notebook(role: Role, save_dir: str = "", name: str = "", save_to_depth=False):
    """Write the role's executed notebook to disk.

    The notebook held by ``role.execute_code.nb`` is first passed through
    ``process_cells`` and then written to ``<save_dir>/<name>.ipynb``. The
    directory is created if it does not exist.

    Args:
        role: Object exposing ``planner.plan.tasks`` and ``execute_code.nb``.
        save_dir: Directory that receives the notebook file(s).
        name: Base file name, without the ``.ipynb`` suffix.
        save_to_depth: When true, also write ``<name>_clean.ipynb``, a fresh
            notebook holding one code cell per task that has non-empty code.
    """
    target_dir = Path(save_dir)
    plan_tasks = role.planner.plan.tasks
    executed_nb = process_cells(role.execute_code.nb)

    os.makedirs(target_dir, exist_ok=True)
    nbformat.write(executed_nb, target_dir / f"{name}.ipynb")

    if not save_to_depth:
        return

    # Rebuild a notebook from the task code alone, without any outputs.
    code_only_nb = nbformat.v4.new_notebook()
    code_only_nb.cells.extend(
        [nbformat.v4.new_code_cell(task.code) for task in plan_tasks if task.code]
    )
    nbformat.write(code_only_nb, target_dir / f"{name}_clean.ipynb")
async def load_execute_notebook(role):
    """Re-run every task's code in a fresh notebook on the role's executor.

    The executor's notebook is replaced by a new empty one, together with a
    new ``NotebookClient`` that uses ``role.role_timeout``. Each non-empty
    task code is then run in plan order through ``executor.run``, and the
    success flag and outputs of every run are printed.

    Args:
        role: Object exposing ``planner.plan.tasks``, ``execute_code`` and
            ``role_timeout``.

    Returns:
        The executor (``role.execute_code``) after all code has been run.
    """
    # Collect the code first, before the executor state is reset.
    snippets = [task.code for task in role.planner.plan.tasks if task.code]

    runner = role.execute_code
    fresh_nb = nbformat.v4.new_notebook()
    runner.nb = fresh_nb
    runner.nb_client = NotebookClient(runner.nb, timeout=role.role_timeout)

    for snippet in snippets:
        outputs, success = await runner.run(snippet)
        print(f"Execution success: {success}, Output: {outputs}")
    print("Finish executing the loaded notebook")
    return runner
def clean_json_from_rsp(text):
    """Extract the content of every json-tagged fenced block in a response.

    Args:
        text: String that may contain fenced blocks opened with a json tag.

    Returns:
        The raw content of all such blocks, joined with newlines in order of
        appearance, or an empty string when there is none. The content is not
        stripped or parsed.
    """
    fenced_bodies = re.findall(r"```json(.*?)```", text, re.DOTALL)
    # Joining an empty list already yields "", so no separate branch is needed.
    return "\n".join(fenced_bodies)