get_sessions.py
"""
Given a Hive table of requests grouped by IP, UA, XFF and day:
1. convert the Hive data format into JSON-like request dicts
2. break each day's worth of requests for each "client" into sessions
3. remove the Main Page and non-main-namespace articles from sessions
4. output [{'lang': wikipedia lang, 'title': article title, 'id': wikidata id}]
   for each session, ordered by time of request

Usage:
spark-submit \
    --driver-memory 5g \
    --master yarn \
    --deploy-mode client \
    --num-executors 10 \
    --executor-memory 10g \
    --executor-cores 4 \
    --queue priority \
    get_sessions.py \
    --release test \
    --lang en
"""

from pyspark import SparkConf, SparkContext
import json
import argparse
import datetime
import os
import time
import subprocess

def parse_requests(requests):
    """Parse one pipe-delimited Hive row into a time-sorted list of request dicts."""
    ret = []
    for r in requests.split('||'):
        t = r.split('|')
        if (len(t) % 2) != 0:  # should be a list of (name, value) pairs and contain at least id, ts, title
            continue
        data_dict = {t[i]: t[i + 1] for i in range(0, len(t), 2)}
        ret.append(data_dict)
    ret.sort(key=lambda x: x['ts'])  # sort by time of request
    return ret

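# Illustrative sketch of the expected row format (the values below are hypothetical;
# only the '||' / '|' delimiters and the id, lang, title, ts fields are taken from the code):
#   'id|Q1|lang|en|title|Foo|ts|2016-03-01 12:00:05||id|Q2|lang|en|title|Bar|ts|2016-03-01 12:04:10'
# parses into two dicts, sorted by 'ts':
#   [{'id': 'Q1', 'lang': 'en', 'title': 'Foo', 'ts': '2016-03-01 12:00:05'}, {'id': 'Q2', ...}]
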
def parse_dates(requests):
    """Convert each request's 'ts' string into a datetime; drop requests whose timestamp cannot be parsed."""
    clean = []
    for r in requests:
        try:
            r['ts'] = datetime.datetime.strptime(r['ts'], '%Y-%m-%d %H:%M:%S')
            clean.append(r)
        except (KeyError, ValueError):  # missing or malformed timestamp
            pass
    return clean

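# For example, a well-formed 'ts' of '2016-03-01 12:04:10' becomes
# datetime.datetime(2016, 3, 1, 12, 4, 10); anything else is silently dropped.
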
def sessionize(requests):
    """
    Break the request stream into sessions wherever
    there is a gap of more than 30 minutes between requests.
    """
    sessions = []
    session = [requests[0]]
    for r in requests[1:]:
        d = r['ts'] - session[-1]['ts']
        if d > datetime.timedelta(minutes=30):
            sessions.append(session)
            session = [r, ]
        else:
            session.append(r)
    sessions.append(session)
    return sessions

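# Sketch of the behaviour (hypothetical timestamps): requests at 10:00, 10:10
# and 11:00 yield two sessions, [[10:00, 10:10], [11:00]], because the
# 50-minute gap between 10:10 and 11:00 exceeds 30 minutes.
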
def filter_consecutive_articles(requests):
    """
    Looking at the data, there are a lot of sessions
    with the same article requested twice in a row.
    This does not make sense for training, so collapse
    consecutive repeats into a single request.
    """
    r = requests[0]
    t = r['title']
    clean_rs = [r, ]
    prev_t = t
    for r in requests[1:]:
        t = r['title']
        if t == prev_t:
            continue
        else:
            clean_rs.append(r)
            prev_t = t
    return clean_rs

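# For example, a session with titles ['A', 'A', 'B', 'B', 'A'] (hypothetical)
# collapses to ['A', 'B', 'A']; only immediate repeats are removed.
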
def filter_blacklist(requests):
    """
    If the session contains an article in the blacklist,
    drop the whole session. Currently only the Main Page
    (Wikidata item Q5296) is in the blacklist.
    """
    black_list = set(['Q5296', ])
    for r in requests:
        if r['id'] in black_list:
            return False
    return True

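# Used as a predicate in RDD.filter below: a session such as
# [{'id': 'Q1', ...}, {'id': 'Q5296', ...}] (hypothetical) returns False
# and is discarded entirely.
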
def scrub_dates(requests):
    """Drop the 'ts' field once sessionization is done; it is not needed in the output."""
    for r in requests:
        del r['ts']
    return requests

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--request_db', default='a2v', help='hive db')
    parser.add_argument('--release', required=True, help='hive table')
    parser.add_argument('--lang', default='wikidata', help='wikidata will use all langs')
    args = vars(parser.parse_args())
    args['table'] = args['release'].replace('-', '_') + '_requests'

    # create base dirs
    base_dir = '/user/ellery/a2v/data/%(release)s' % args
    print(os.system('hadoop fs -mkdir ' + base_dir))
    local_base_dir = '/a/ellery/a2v/data/%(release)s' % args
    print(os.system('mkdir ' + local_base_dir))

    # define io paths
    args['input_dir'] = '/user/hive/warehouse/%(request_db)s.db/%(table)s' % args
    args['output_dir'] = '/user/ellery/a2v/data/%(release)s/%(release)s_sessions_%(lang)s' % args
    args['local_output_file'] = '/a/ellery/a2v/data/%(release)s/%(release)s_sessions_%(lang)s' % args
    args['local_output_dir'] = '/a/ellery/a2v/data/%(release)s/%(release)s_sessions_%(lang)s_dir' % args

    # clean up old data
    print(os.system('hadoop fs -rm -r %(output_dir)s' % args))
    print(os.system('rm -rf %(local_output_file)s' % args))
    print(os.system('rm -rf %(local_output_dir)s' % args))

    conf = SparkConf()
    conf.set("spark.app.name", 'a2v preprocess')
    sc = SparkContext(conf=conf, pyFiles=[])

    # one record per (client, day): parse the raw row into a sorted list of request dicts
    requests = sc.textFile(args['input_dir']) \
        .map(parse_requests)

    # for a single-language run, keep only that language's requests
    if args['lang'] != 'wikidata':
        requests = requests.map(lambda rs: [r for r in rs if r['lang'] == args['lang']])

    # serialize a session as space-separated Wikidata IDs or article titles
    if args['lang'] == 'wikidata':
        to_str = lambda x: ' '.join([e['id'] for e in x])
    else:
        to_str = lambda x: ' '.join([e['title'] for e in x])

    # drop sessions touching the blacklist, collapse consecutive repeats,
    # parse timestamps, split on 30-minute gaps, and keep sessions of 2-29 requests
    requests \
        .filter(filter_blacklist) \
        .filter(lambda x: len(x) > 1) \
        .map(filter_consecutive_articles) \
        .filter(lambda x: len(x) > 1) \
        .map(parse_dates) \
        .flatMap(sessionize) \
        .filter(lambda x: len(x) > 1) \
        .filter(lambda x: len(x) < 30) \
        .map(scrub_dates) \
        .map(to_str) \
        .saveAsTextFile(args['output_dir'], compressionCodecClass="org.apache.hadoop.io.compress.GzipCodec")

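    # Each output line is one session: space-separated Wikidata IDs for
    # --lang wikidata (e.g. 'Q1 Q2 Q3', hypothetical) or space-separated
    # article titles for a specific language edition.
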
    # transfer data to local
    os.system('hadoop fs -copyToLocal %(output_dir)s %(local_output_dir)s' % args)
    os.system('cat %(local_output_dir)s/* | gunzip > %(local_output_file)s' % args)
    os.system('rm -rf %(local_output_dir)s' % args)