-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathmulti_node2vec.py
99 lines (73 loc) · 4.32 KB
/
multi_node2vec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
'''
Wrapper for the multi-node2vec algorithm.
Details can be found in the paper: "Fast Embedding of Multilayer Networks: An Algorithm and Application to Group fMRI"
by JD Wilson, M Baybay, R Sankar, and P Stillman
Preprint here: https://arxiv.org/pdf/1809.06437.pdf
Contributors:
- Melanie Baybay
University of San Francisco, Department of Computer Science
- Rishi Sankar
Henry M. Gunn High School
- James D. Wilson (maintainer)
University of San Francisco, Department of Mathematics and Statistics
Questions or Bugs? Contact James D. Wilson at [email protected]
'''
import os
import src as mltn2v
import argparse
import time
def _parse_thresh(value):
    """Convert a ``--thresh`` command-line value to a float, or ``None``.

    The ``--thresh`` help text tells users to pass ``None`` for an unweighted
    network, but a plain ``type=float`` rejects that string. This converter
    accepts ``None`` (case-insensitive, surrounding whitespace ignored) and
    otherwise defers to ``float``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is neither ``None`` nor a
            valid float (argparse turns this into a usage error).
    """
    if value.strip().lower() == 'none':
        return None
    try:
        return float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            "invalid threshold {!r}: expected a number or None".format(value))


def parse_args(argv=None):
    """Build the multi-node2vec command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which matches the
            original zero-argument behaviour.

    Returns:
        argparse.Namespace with the attributes ``dir``, ``output``, ``d``,
        ``walk_length``, ``window_size``, ``n_samples``, ``thresh``,
        ``w2v_workers``, ``rvals``, ``pvals`` and ``qvals``. ``rvals`` is a
        single float here. ``thresh`` is a float or ``None``.
    """
    parser = argparse.ArgumentParser(description="Run multi-node2vec on multilayer networks.")
    parser.add_argument('--dir', nargs='?', default='data/CONTROL_fmt',
                        help='Absolute path to directory of correlation/adjacency matrix files (csv format). Note that rows and columns must be properly labeled by node ID in each .csv.')
    parser.add_argument('--output', nargs='?', default='new_results/',
                        help='Absolute path to output directory (no extension).')
    parser.add_argument('--d', type=int, default=100,
                        help='Dimensionality. Default is 100.')
    parser.add_argument('--walk_length', type=int, default=100,
                        help='Length of each random walk. Default is 100.')
    parser.add_argument('--window_size', type=int, default=10,
                        help='Size of context window used for Skip Gram optimization. Default is 10.')
    parser.add_argument('--n_samples', type=int, default=1,
                        help='Number of walks per node per layer. Default is 1.')
    # _parse_thresh (rather than float) so that the documented "None" value
    # is actually accepted on the command line.
    parser.add_argument('--thresh', type=_parse_thresh, default=0.5,
                        help='Threshold for converting a weighted network to an unweighted one. All weights less than or equal to thresh will be considered 0 and all others 1. Default is 0.5. Use None if the network is unweighted.')
    parser.add_argument('--w2v_workers', type=int, default=8,
                        help='Number of parallel worker threads. Default is 8.')
    parser.add_argument('--rvals', type=float, default=0.25,
                        help='Layer walk parameter for neighborhood search. Default is 0.25')
    # argparse does not apply `type` to non-string defaults, so write the
    # default as a float to keep the attribute's type consistent.
    parser.add_argument('--pvals', type=float, default=1.0,
                        help='Return walk parameter for neighborhood search. Default is 1')
    parser.add_argument('--qvals', type=float, default=0.5,
                        help='Exploration walk parameter for neighborhood search. Default is 0.50')
    return parser.parse_args(argv)
def main(args):
    """Run the multi-node2vec pipeline described by the parsed options.

    Steps:
      1. parse the layer matrices in ``args.dir`` with ``binary=True`` and
         ``thresh=args.thresh``;
      2. extract random-walk neighbourhoods for every layer-walk value ``r``;
      3. generate features for each ``r`` under ``<output>/r<r>/mltn2v_results``.

    Args:
        args: argparse.Namespace as produced by ``parse_args``.
            ``args.rvals`` may be a single number or an iterable of numbers.
    """
    start = time.time()

    # PARSE LAYERS -- THRESHOLD & CONVERT TO BINARY
    layers = mltn2v.timed_invoke("parsing network layers",
                                 lambda: mltn2v.parse_matrix_layers(args.dir, binary=True, thresh=args.thresh))
    # Bail out early if nothing was parsed.
    if not layers:
        print("Whoops! No network layers were parsed from '{}'.".format(args.dir))
        return

    # parse_args() yields a single float for --rvals; accept that directly
    # instead of failing on iteration, while still allowing a list of values.
    rvals = args.rvals
    try:
        iter(rvals)
    except TypeError:
        rvals = [rvals]

    # EXTRACT NEIGHBORHOODS
    # NOTE(review): args.n_samples is parsed but never passed on here;
    # confirm whether extract_neighborhoods_walk should receive it.
    nbrhd_dict = mltn2v.timed_invoke("extracting neighborhoods",
                                     lambda: mltn2v.extract_neighborhoods_walk(layers, args.walk_length, rvals, args.pvals, args.qvals))

    # GENERATE FEATURES
    out = mltn2v.clean_output(args.output)
    for w in rvals:
        # Let os.path.join insert the separators instead of embedding '/'.
        out_path = os.path.join(out, 'r' + str(w), 'mltn2v_results')
        # NOTE(review): w2v_iter is hard-coded to 1. The lambda captures the
        # loop variables, which assumes timed_invoke calls it immediately;
        # confirm.
        mltn2v.timed_invoke("generating features",
                            lambda: mltn2v.generate_features(nbrhd_dict[w], args.d, out_path, nbrhd_size=args.window_size,
                                                             w2v_iter=1, workers=args.w2v_workers))
        # Elapsed time is measured from the start of main(), i.e. cumulative.
        print("\nCompleted Multilayer Network Embedding for r=" + str(w) + " in {:.2f} secs.\nSee results:".format(time.time() - start))
        print("\t" + out_path + ".csv")
    print("Completed Multilayer Network Embedding for all r values.")
if __name__ == '__main__':
    # Script entry point: parse the CLI options and run the pipeline.
    cli_args = parse_args()
    # --rvals arrives as a single float, but main() iterates over r values,
    # so wrap it in a one-element list.
    cli_args.rvals = [cli_args.rvals]
    main(cli_args)