forked from KongBOy/analyze_16_code_iin
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathedsetup
84 lines (67 loc) · 3.02 KB
/
edsetup
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#!/usr/bin/env python3
import os
import shutil
import yaml
import edflow
def _ensure_parent_dir(file_path):
    """Create the directory that will contain ``file_path`` if it is missing."""
    parent = os.path.dirname(file_path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _module_address(file_path):
    """Turn a file path such as ``proj/model.py`` into ``proj.model.Model``.

    The extension is dropped with ``os.path.splitext`` (instead of deleting every
    occurrence of "py") and the path is split on the platform separator, so
    project names containing "py" and nested or Windows paths are handled.
    The class name is the capitalised file stem.
    """
    module_path, _extension = os.path.splitext(os.path.normpath(file_path))
    parts = [
        part for part in module_path.split(os.sep) if part not in ("", os.curdir)
    ]
    return ".".join(parts + [parts[-1].capitalize()])


def create_edflow_project(project_name, replace: bool = False, **kwargs):
    """
    Creates a fresh new project directory with all the needed skeletons which are edflow compatible.

    The skeleton files ``model.py``, ``dataset.py`` and ``iterator.py`` are copied from
    the ``edsetup_files`` directory of the installed ``edflow`` package, and the template
    ``config.yaml`` from the same directory is written to the project with its ``model``,
    ``iterator``, ``datasets.train`` and ``datasets.validation`` entries pointing at the
    copied files.

    Parameters
    ----------
    project_name (str) : Name of your new project. This will also be the name of the parent folder.
    replace (bool) : If a old project shall be replaced in case it exists. When True, the
        existing directory (or file) at ``project_name`` is DELETED before the new project
        is created.
    kwargs : optional file paths for the config, model, dataset and iterator classes as <class>_path="path to file"
        e.g. config_path="config/config.yaml" EXCLUDING the project directory.
        Missing sub-directories inside the project are created.

    Raises
    ------
    IsADirectoryError : if ``project_name`` already exists and ``replace`` is False.
    """
    if os.path.exists(project_name):
        if not replace:
            raise IsADirectoryError(
                f"Directory: {project_name} already exists. Provide a different name for the project "
                "OR set replace to True"
            )
        # Actually replace the old project; a plain mkdir on an existing path
        # would raise FileExistsError and make replace=True unusable.
        if os.path.isdir(project_name) and not os.path.islink(project_name):
            shutil.rmtree(project_name)
        else:
            os.remove(project_name)
    os.makedirs(project_name)

    # Locate the packaged templates via the package directory rather than by
    # string-replacing "__init__.py" in the module file name.
    parent_source = os.path.join(os.path.dirname(edflow.__file__), "edsetup_files")

    destination_config = os.path.join(
        project_name, kwargs.get("config_path", "config.yaml")
    )
    source_config = os.path.join(parent_source, "config.yaml")
    # The template is only read, so open it read-only; "r+" fails when the
    # installed package is not writable.
    with open(source_config, "r") as source_config_file:
        source_config_dict = yaml.load(source_config_file, Loader=yaml.FullLoader)

    # (role, kwarg that overrides the destination, template file name)
    skeletons = (
        ("model", "model_path", "model.py"),
        ("dataset", "dataset_path", "dataset.py"),
        ("iterator", "iterator_path", "iterator.py"),
    )
    addresses = {}
    for role, path_key, template_name in skeletons:
        destination_path = os.path.join(
            project_name, kwargs.get(path_key, template_name)
        )
        _ensure_parent_dir(destination_path)
        shutil.copy(os.path.join(parent_source, template_name), destination_path)
        # NOTE(review): the class name is derived from the destination file name
        # (model.py -> Model), as before; presumably this matches the class defined
        # in the shipped skeleton - verify when a custom file name is used.
        addresses[role] = _module_address(destination_path)

    # Train and validation both use the single dataset skeleton.
    # Assumes the template config contains a "datasets" mapping.
    source_config_dict["datasets"]["train"] = addresses["dataset"]
    source_config_dict["datasets"]["validation"] = addresses["dataset"]
    source_config_dict["model"] = addresses["model"]
    source_config_dict["iterator"] = addresses["iterator"]

    _ensure_parent_dir(destination_config)
    with open(destination_config, "w") as new_config_file:
        yaml.dump(source_config_dict, new_config_file)
if __name__ == "__main__":
    # Command-line entry point: `-n <project_name>` selects the project
    # directory to create; it defaults to "edflow_project".
    from argparse import ArgumentParser

    cli = ArgumentParser()
    cli.add_argument(
        "-n",
        type=str,
        default="edflow_project",
        metavar="project_name",
    )
    create_edflow_project(cli.parse_args().n)