-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathgenerate_data.py
66 lines (58 loc) · 2.27 KB
/
generate_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import argparse
from datetime import datetime
from configs.generate_data import sw, tcga
from experiments.io import save_dataset
from experiments.logger import create_logger
from experiments.utils import init_seeds
from simulation.utils import (create_dataset, get_data_generator,
get_outcome_generator,
get_treatment_assignment_policy)
def parse_default_args() -> argparse.Namespace:
    """Parse the command line in two passes and return the final namespace.

    The first, lenient pass reads the shared options only to learn ``--task``.
    The matching task config module (``sw`` or ``tcga``) is then handed the
    parser through ``add_params`` (presumably to register its task-specific
    options), and a second, strict pass produces the returned namespace.

    Automatic ``-h/--help`` handling is switched off while the parser is being
    built and the help action is registered only after the task module has
    been applied. With the automatic handling enabled, the first pass would
    print the help text and exit before ``add_params`` ran, so task-specific
    options could never show up in ``--help`` output.

    Returns:
        The namespace produced by the strict ``parse_args`` call.
    """
    parser = argparse.ArgumentParser(
        description="StructuredInterventionNetworks", add_help=False
    )
    # The default run name is a timestamp taken when the parser is built.
    parser.add_argument(
        "--name", type=str, default="{:%Y_%m_%d_%H_%M_%S_%f}".format(datetime.now())
    )
    parser.add_argument("--task", type=str, default="tcga", choices=["sw", "tcga"])
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--cuda", type=int, default=0)
    parser.add_argument(
        "--data_path",
        type=str,
        default="./generated_data/",
        help="Path to save/load generated data",
    )

    # Lenient pass: options that are not registered yet (including -h/--help)
    # are collected as leftovers instead of raising an error or exiting.
    known_args, _ = parser.parse_known_args()
    if known_args.task == "sw":
        sw.add_params(parser)
    elif known_args.task == "tcga":
        tcga.add_params(parser)

    # Register help last so its output covers the task-specific options too.
    parser.add_argument(
        "-h",
        "--help",
        action="help",
        default=argparse.SUPPRESS,
        help="show this help message and exit",
    )

    # Strict pass: anything still unrecognised is now a usage error.
    return parser.parse_args()
def main() -> None:
    """Run the data-generation script end to end.

    Parses the command line, creates a logger writing to ``log/<name>.log``,
    seeds via ``init_seeds``, builds the datasets and the treatment graph
    mapping with ``create_dataset``, wires the treatment assignment policy and
    the outcome generator into a data generator, generates the train and test
    data, and finally hands the result to ``save_dataset``.
    """
    args = parse_default_args()

    log_file = "log/%s.log" % args.name
    logger = create_logger(log_file)
    logger.info(args)

    init_seeds(args.seed)

    train_split, holdout_split, graphs_by_id = create_dataset(args)

    policy = get_treatment_assignment_policy(
        treatment_ids=list(graphs_by_id.keys()),
        args=args,
    )
    outcomes = get_outcome_generator(id_to_graph_dict=graphs_by_id, args=args)
    generator = get_data_generator(
        task=args.task,
        id_to_graph_dict=graphs_by_id,
        treatment_assignment_policy=policy,
        outcome_generator=outcomes,
        in_sample_dataset=train_split,
        out_sample_dataset=holdout_split,
        args=args,
    )

    logger.info("Generate outcomes...")
    # The return value of the train step is not used here; only the test
    # units are passed on explicitly, together with the in-sample dataset.
    generator.generate_train_data()
    save_dataset(
        in_sample_dataset=train_split,
        test_units=generator.generate_test_data(),
        args=args,
    )
# Only run the generation pipeline when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()