from __future__ import absolute_import, division, print_function

import os
import sys
import argparse
import pickle
import threading
from math import log
from datetime import datetime
from collections import namedtuple
from functools import reduce

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.distributions import Categorical
from tqdm import tqdm

from knowledge_graph import KnowledgeGraph
from kg_env import BatchKGEnvironment
from train_agent import ActorCritic
from utils import *


def evaluate(topk_matches, test_user_products):
    """Compute metrics for predicted recommendations.

    Args:
        topk_matches: dict mapping a user id to its top-k recommended product ids,
            sorted in ascending order of score (the best product comes last).
        test_user_products: dict mapping a user id to its ground-truth product ids.
    """
    invalid_users = []
    # Compute metrics
    precisions, recalls, ndcgs, hits = [], [], [], []
    test_user_idxs = list(test_user_products.keys())
    for uid in test_user_idxs:
        if uid not in topk_matches or len(topk_matches[uid]) < 10:
            invalid_users.append(uid)
            continue
        # Reverse the predictions so the highest-scored product comes first.
        pred_list, rel_set = topk_matches[uid][::-1], test_user_products[uid]
        if len(pred_list) == 0:
            continue

        dcg = 0.0
        hit_num = 0.0
        for i in range(len(pred_list)):
            if pred_list[i] in rel_set:
                dcg += 1. / (log(i + 2) / log(2))
                hit_num += 1
        # IDCG: ideal DCG, with all relevant products ranked at the top.
        idcg = 0.0
        for i in range(min(len(rel_set), len(pred_list))):
            idcg += 1. / (log(i + 2) / log(2))
        ndcg = dcg / idcg
        recall = hit_num / len(rel_set)
        precision = hit_num / len(pred_list)
        hit = 1.0 if hit_num > 0.0 else 0.0

        ndcgs.append(ndcg)
        recalls.append(recall)
        precisions.append(precision)
        hits.append(hit)

    avg_precision = np.mean(precisions) * 100
    avg_recall = np.mean(recalls) * 100
    avg_ndcg = np.mean(ndcgs) * 100
    avg_hit = np.mean(hits) * 100
    print('NDCG={:.3f} | Recall={:.3f} | HR={:.3f} | Precision={:.3f} | Invalid users={}'.format(
        avg_ndcg, avg_recall, avg_hit, avg_precision, len(invalid_users)))
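

# Illustrative example of the inputs evaluate() expects (the ids below are made up,
# not taken from any dataset):
#   >>> topk_matches = {0: [53, 11, 97, 4, 88, 21, 7, 64, 30, 2]}  # ascending score, best pid last
#   >>> test_user_products = {0: [2, 15, 30]}                      # ground-truth pids for user 0
#   >>> evaluate(topk_matches, test_user_products)
#   NDCG=... | Recall=... | HR=... | Precision=... | Invalid users=0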


def batch_beam_search(env, model, uids, device, topk=[25, 5, 1]):
    def _batch_acts_to_masks(batch_acts):
        batch_masks = []
        for acts in batch_acts:
            num_acts = len(acts)
            act_mask = np.zeros(model.act_dim, dtype=np.uint8)
            act_mask[:num_acts] = 1
            batch_masks.append(act_mask)
        return np.vstack(batch_masks)

    state_pool = env.reset(uids)  # numpy array of [bs, dim]
    path_pool = env._batch_path  # list of lists, size=bs
    probs_pool = [[] for _ in uids]
    model.eval()
    for hop in range(3):
        state_tensor = torch.FloatTensor(state_pool).to(device)
        acts_pool = env._batch_get_actions(path_pool, False)  # list of lists, size=bs
        actmask_pool = _batch_acts_to_masks(acts_pool)  # numpy array of [bs, act_dim]
        actmask_tensor = torch.ByteTensor(actmask_pool).to(device)
        probs, _ = model((state_tensor, actmask_tensor))  # Tensor of [bs, act_dim]
        # Add 1 to every valid action's probability so valid actions always rank
        # above masked (zero-probability) ones in the top-k selection.
        probs = probs + actmask_tensor.float()
        topk_probs, topk_idxs = torch.topk(probs, topk[hop], dim=1)  # [bs, k]
        topk_idxs = topk_idxs.detach().cpu().numpy()
        topk_probs = topk_probs.detach().cpu().numpy()

        new_path_pool, new_probs_pool = [], []
        for row in range(topk_idxs.shape[0]):
            path = path_pool[row]
            probs = probs_pool[row]
            for idx, p in zip(topk_idxs[row], topk_probs[row]):
                if idx >= len(acts_pool[row]):  # act idx is invalid
                    continue
                relation, next_node_id = acts_pool[row][idx]  # (relation, next_node_id)
                if relation == SELF_LOOP:
                    next_node_type = path[-1][1]
                else:
                    next_node_type = KG_RELATION[path[-1][1]][relation]
                new_path = path + [(relation, next_node_type, next_node_id)]
                new_path_pool.append(new_path)
                new_probs_pool.append(probs + [p])
        path_pool = new_path_pool
        probs_pool = new_probs_pool
        if hop < 2:
            state_pool = env._batch_get_state(path_pool)

    return path_pool, probs_pool
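

# A sketch of what batch_beam_search returns (entity ids are illustrative):
#   path_pool[i] is a list of (relation, node_type, node_id) triples starting from the queried
#   user, e.g. [(SELF_LOOP, USER, u), (r1, t1, e1), (r2, t2, e2), (r3, t3, e3)].
#   probs_pool[i] holds the per-hop selection scores of that path. Note these are the
#   mask-shifted values (probability + 1 for valid actions), so their product, used in
#   evaluate_paths() below, is a ranking score rather than a true path probability.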


def predict_paths(policy_file, path_file, args):
    print('Predicting paths...')
    env = BatchKGEnvironment(args.dataset, args.max_acts, max_path_len=args.max_path_len,
                             state_history=args.state_history)
    pretrain_sd = torch.load(policy_file)
    model = ActorCritic(env.state_dim, env.act_dim, gamma=args.gamma, hidden_sizes=args.hidden).to(args.device)
    model_sd = model.state_dict()
    model_sd.update(pretrain_sd)
    model.load_state_dict(model_sd)

    test_labels = load_labels(args.dataset, 'test')
    test_uids = list(test_labels.keys())

    batch_size = 16
    start_idx = 0
    all_paths, all_probs = [], []
    pbar = tqdm(total=len(test_uids))
    while start_idx < len(test_uids):
        end_idx = min(start_idx + batch_size, len(test_uids))
        batch_uids = test_uids[start_idx:end_idx]
        paths, probs = batch_beam_search(env, model, batch_uids, args.device, topk=args.topk)
        all_paths.extend(paths)
        all_probs.extend(probs)
        pbar.update(end_idx - start_idx)  # last batch may be smaller than batch_size
        start_idx = end_idx
    predicts = {'paths': all_paths, 'probs': all_probs}
    with open(path_file, 'wb') as f:
        pickle.dump(predicts, f)
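

# The pickle written above is a dict {'paths': [...], 'probs': [...]} with one entry per
# surviving beam, and is consumed by evaluate_paths() below.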


def evaluate_paths(path_file, train_labels, test_labels):
    # NOTE: this function also reads the module-level `args` (set in __main__) for
    # args.dataset and args.add_products.
    embeds = load_embed(args.dataset)
    user_embeds = embeds[USER]
    purchase_embeds = embeds[PURCHASE][0]
    product_embeds = embeds[PRODUCT]
    # Embedding score of every (user, product) pair: (e_user + e_purchase) . e_product
    scores = np.dot(user_embeds + purchase_embeds, product_embeds.T)

    # 1) Get all valid paths for each user, compute path score and path probability.
    results = pickle.load(open(path_file, 'rb'))
    pred_paths = {uid: {} for uid in test_labels}
    for path, probs in zip(results['paths'], results['probs']):
        if path[-1][1] != PRODUCT:
            continue
        uid = path[0][2]
        if uid not in pred_paths:
            continue
        pid = path[-1][2]
        if pid not in pred_paths[uid]:
            pred_paths[uid][pid] = []
        path_score = scores[uid][pid]
        path_prob = reduce(lambda x, y: x * y, probs)
        pred_paths[uid][pid].append((path_score, path_prob, path))

    # 2) Pick best path for each user-product pair, also remove pid if it is in train set.
    best_pred_paths = {}
    for uid in pred_paths:
        train_pids = set(train_labels[uid])
        best_pred_paths[uid] = []
        for pid in pred_paths[uid]:
            if pid in train_pids:
                continue
            # Keep the path with the highest probability.
            sorted_path = sorted(pred_paths[uid][pid], key=lambda x: x[1], reverse=True)
            best_pred_paths[uid].append(sorted_path[0])

    # 3) Compute top 10 recommended products for each user.
    sort_by = 'score'
    pred_labels = {}
    for uid in best_pred_paths:
        if sort_by == 'score':
            sorted_path = sorted(best_pred_paths[uid], key=lambda x: (x[0], x[1]), reverse=True)
        elif sort_by == 'prob':
            sorted_path = sorted(best_pred_paths[uid], key=lambda x: (x[1], x[0]), reverse=True)
        top10_pids = [p[-1][2] for _, _, p in sorted_path[:10]]  # from largest to smallest
        # Backfill with top embedding-scored products if fewer than 10 paths were found.
        if args.add_products and len(top10_pids) < 10:
            train_pids = set(train_labels[uid])
            cand_pids = np.argsort(scores[uid])
            for cand_pid in cand_pids[::-1]:
                if cand_pid in train_pids or cand_pid in top10_pids:
                    continue
                top10_pids.append(cand_pid)
                if len(top10_pids) >= 10:
                    break
        pred_labels[uid] = top10_pids[::-1]  # change order to ascending, as expected by evaluate()

    evaluate(pred_labels, test_labels)
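

# In short: for each test user the candidates are products reached by at least one valid path
# ending at a product node and not already in the training set; they are ranked primarily by the
# embedding score (and secondarily by path probability), and with --add_products any remaining
# slots are filled by the highest raw embedding scores.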


def test(args):
    policy_file = args.log_dir + '/policy_model_epoch_{}.ckpt'.format(args.epochs)
    path_file = args.log_dir + '/policy_paths_epoch{}.pkl'.format(args.epochs)

    train_labels = load_labels(args.dataset, 'train')
    test_labels = load_labels(args.dataset, 'test')

    if args.run_path:
        predict_paths(policy_file, path_file, args)
    if args.run_eval:
        evaluate_paths(path_file, train_labels, test_labels)


if __name__ == '__main__':
    boolean = lambda x: (str(x).lower() == 'true')

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default=BEAUTY, help='One of {cloth, beauty, cell, cd}')
    parser.add_argument('--name', type=str, default='train_agent', help='directory name.')
    parser.add_argument('--seed', type=int, default=123, help='random seed.')
    parser.add_argument('--gpu', type=str, default='0', help='gpu device.')
    parser.add_argument('--epochs', type=int, default=50, help='num of epochs.')
    parser.add_argument('--max_acts', type=int, default=250, help='Max number of actions.')
    parser.add_argument('--max_path_len', type=int, default=3, help='Max path length.')
    parser.add_argument('--gamma', type=float, default=0.99, help='reward discount factor.')
    parser.add_argument('--state_history', type=int, default=1, help='state history length')
    parser.add_argument('--hidden', type=int, nargs='*', default=[512, 256], help='hidden layer sizes of the policy network.')
    parser.add_argument('--add_products', type=boolean, default=False, help='Add predicted products up to 10')
    parser.add_argument('--topk', type=int, nargs='*', default=[25, 5, 1], help='beam width at each hop.')
    parser.add_argument('--run_path', type=boolean, default=True, help='Generate predicted path? (takes long time)')
    parser.add_argument('--run_eval', type=boolean, default=True, help='Run evaluation?')
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    args.log_dir = TMP_DIR[args.dataset] + '/' + args.name
    test(args)
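
# Example usage (illustrative; assumes train_agent.py has already saved a policy checkpoint in the
# corresponding log directory, and that <dataset> is one of the dataset names defined in utils):
#   python test_agent.py --dataset <dataset> --name train_agent --epochs 50 --topk 25 5 1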