-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathnesi_base2.py
executable file
·112 lines (97 loc) · 2.98 KB
/
nesi_base2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
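Entrypoint for NeSI workers.

Takes the following CLI arguments:

    fragment_id      int, positional; passed to librun.run as fragment_id.
    fragment_length  int, positional; passed to librun.run as fragment_length.
    fragment_run     positional, "START" or "START-END"; the integer bounds are
                     passed to librun.run as fragment_run_start and
                     fragment_run_end (the end is None when omitted).
    --dry-run        flag; passed to librun.run as dry_run.
    --workers N      int, default None; passed to librun.run as workers.
    --nobackup       flag; sets the OUT_DIR environment variable to
                     /home/zpul156/out_nobackup before librun.run is called.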
"""
import argparse
import os
from dotenv import load_dotenv
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import librun
from libdatasets import *
from libadversarial import uncertainty_stop
# One row per dataset: (label, fetcher, second argument handed to wrap()).
# Dataset fetchers should cache if possible.
_DATASET_SPECS = (
    ("rcv1-58509", rcv1, 58509),
    ("webkb", webkb, None),
    ("spamassassin", spamassassin, None),
    ("avila", avila, None),
    ("smartphone", smartphone, None),
    ("swarm", swarm, None),
    ("sensorless", sensorless, None),
    ("splice", splice, None),
    ("anuran", anuran, None),
)

# Settings stored under the "meta" key of the experiment matrix.
# NOTE(review): the meaning of each setting is defined by librun, which is not
# visible from this file -- confirm there before changing values.
_META = {
    "dataset_size": 1000,
    "labelled_size": 10,
    "test_size": 0.5,
    "n_runs": 10,
    "ret_classifiers": True,
    "ensure_y": True,
    "stop_info": True,
    "aggregate": False,
    # (label, predicate) pair; the predicate is true once fewer than 510 rows
    # remain in state.X_unlabelled.
    "stop_function": (
        "res500",
        lambda learner, matrix, state: state.X_unlabelled.shape[0] < 510,
    ),
    "pool_subsample": 1000,
}

# Experiment matrix handed to librun.run().
matrix = {
    # Each fetcher goes through wrap() because the callable must be pickleable
    # (it is sent to other threads via joblib).
    "datasets": [
        (label, wrap(fetcher, arg)) for label, fetcher, arg in _DATASET_SPECS
    ],
    "dataset_mutators": {
        # Identity mutator: returns its positional arguments as a tuple.
        "none": (lambda *x, **kwargs: x),
    },
    "methods": [
        ("uncertainty", partial(uncertainty_stop, n_instances=10)),
    ],
    "models": ["svm-linear", "random-forest", "neural-network"],
    "meta": _META,
}
# Scoring callables imported from sklearn.metrics.
_METRIC_FUNCTIONS = (accuracy_score, f1_score, roc_auc_score)

# Metrics requested by name.
# NOTE(review): presumably resolved to implementations inside librun -- confirm.
_NAMED_METRICS = (
    "time",
    "time_total",
    "uncertainty_average",
    "uncertainty_min",
    "uncertainty_max",
    "uncertainty_variance",
    "uncertainty_average_selected",
    "uncertainty_min_selected",
    "uncertainty_max_selected",
    "uncertainty_variance_selected",
    "entropy_max",
    "n_support",
    "contradictory_information",
)

# Metrics passed to librun.run(): the callables first, then the named ones.
capture_metrics = [*_METRIC_FUNCTIONS, *_NAMED_METRICS]
def _parse_fragment_run(spec):
    """Split a ``START`` or ``START-END`` specifier into ``(start, end)``.

    Args:
        spec: The raw ``fragment_run`` command-line value.

    Returns:
        A ``(start, end)`` tuple. ``start`` is an int; ``end`` is an int when
        *spec* has the ``START-END`` form and ``None`` otherwise.

    Raises:
        ValueError: If a bound is not an integer, or if *spec* contains more
            than one ``-`` separator (e.g. ``"1-2-3"``).
    """
    bounds = spec.split("-")
    # Reject extra fields explicitly instead of silently ignoring them.
    if len(bounds) > 2:
        raise ValueError(f"expected START or START-END, got {spec!r}")
    start = int(bounds[0])
    end = int(bounds[1]) if len(bounds) == 2 else None
    return start, end


def main(argv=None):
    """Parse the command line and launch ``librun.run`` on the matrix.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads the arguments from ``sys.argv``.

    A malformed ``fragment_run`` value is reported through ``parser.error``,
    which prints a usage message and exits with status 2.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("fragment_id", type=int)
    parser.add_argument("fragment_length", type=int)
    parser.add_argument("fragment_run")
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--workers", type=int, default=None)
    parser.add_argument("--nobackup", action="store_true")
    args = parser.parse_args(argv)

    try:
        start, end = _parse_fragment_run(args.fragment_run)
    except ValueError:
        parser.error(
            "fragment_run must be START or START-END with integer bounds, "
            f"got {args.fragment_run!r}"
        )

    if args.nobackup:
        # Redirect output to the hard-coded nobackup directory.
        # NOTE(review): presumably librun reads OUT_DIR when it runs -- confirm.
        os.environ["OUT_DIR"] = "/home/zpul156/out_nobackup"

    librun.run(
        matrix,
        metrics=capture_metrics,
        # abort=False,
        fragment_id=args.fragment_id,
        fragment_length=args.fragment_length,
        fragment_run_start=start,
        fragment_run_end=end,
        dry_run=args.dry_run,
        workers=args.workers,
    )
# Script entry point: runs only when the file is executed directly.
if __name__ == "__main__":
    # python-dotenv: load variables from a .env file into the environment
    # before main() runs.
    load_dotenv()
    main()