-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
38 changed files
with
1,968 additions
and
875 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
*.pyc | ||
polara.egg-info/ | ||
examples/.ipynb_checkpoints/ | ||
.ipynb_checkpoints/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,11 @@ | ||
# This file may be used to create an environment using: | ||
# $ conda create --name <env> --file <this file> | ||
|
||
python>=3.6 | ||
jupyter>=1.0.0 | ||
numba>=0.21.0 | ||
numpy>=1.10.1 | ||
matplotlib>=1.4.3 | ||
pandas>=0.17.1 | ||
requests>=2.7.0 | ||
scipy>=0.16.0 | ||
seaborn>=0.6.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
62 changes: 47 additions & 15 deletions
62
examples/Reproducing EIGENREC results.ipynb → examples/Reproducing_EIGENREC_results.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,16 @@ | ||
# import standard baseline models | ||
from polara.recommender.models import RecommenderModel | ||
from polara.recommender.models import SVDModel | ||
from polara.recommender.models import CooccurrenceModel | ||
from polara.recommender.models import RandomModel | ||
from polara.recommender.models import PopularityModel | ||
from polara.recommender.models import (RecommenderModel, | ||
SVDModel, | ||
ScaledSVD, | ||
CooccurrenceModel, | ||
RandomModel, | ||
PopularityModel) | ||
|
||
# import data model | ||
from polara.recommender.data import RecommenderData | ||
|
||
# import data management routines | ||
from polara.datasets.movielens import get_movielens_data | ||
from polara.datasets.bookcrossing import get_bx_data | ||
from polara.datasets.bookcrossing import get_bookcrossing_data | ||
from polara.datasets.netflix import get_netflix_data | ||
from polara.datasets.amazon import get_amazon_data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from ast import literal_eval | ||
import gzip | ||
import pandas as pd | ||
|
||
|
||
def parse_meta(path):
    """Lazily parse a gzip-compressed text file holding one Python literal per line.

    Parameters
    ----------
    path : str or path-like
        Location of the gzip-compressed file; it is opened in text mode.

    Yields
    ------
    object
        The value produced by ``ast.literal_eval`` for each non-blank line.
    """
    with gzip.open(path, 'rt') as gz:
        for line in gz:
            # A blank line (e.g. a trailing newline at the end of the dump)
            # is not a valid literal and would make literal_eval raise.
            if not line.strip():
                continue
            yield literal_eval(line)
|
||
|
||
def get_amazon_data(path=None, meta_path=None, nrows=None):
    """Load ratings and/or metadata into pandas DataFrames.

    Parameters
    ----------
    path : str, optional
        Header-less CSV with columns userid, asin, rating, timestamp;
        the timestamp column is not loaded.
    meta_path : str, optional
        Gzipped metadata file, read through ``parse_meta``.
    nrows : int, optional
        Maximum number of rows to read from each source.

    Returns
    -------
    pandas.DataFrame or list
        A single DataFrame when exactly one source is requested, otherwise
        a list (empty when no source is given; ``[ratings, meta]`` when
        both are given).
    """
    loaded = []
    if path:
        columns = ['userid', 'asin', 'rating', 'timestamp']
        ratings = pd.read_csv(path, header=None, names=columns,
                              usecols=columns[:3], nrows=nrows)
        loaded.append(ratings)
    if meta_path:
        records = parse_meta(meta_path)
        loaded.append(pd.DataFrame.from_records(records, nrows=nrows))
    return loaded[0] if len(loaded) == 1 else loaded
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import numpy as np | ||
import scipy as sp | ||
import pandas as pd | ||
|
||
|
||
def compute_graph_laplacian(edges, index):
    """Build an undirected adjacency matrix and its graph Laplacian.

    Parameters
    ----------
    edges : iterable of pairs
        Edge endpoints given as labels that are looked up in ``index``.
        Pairs with a label missing from ``index`` and self-links are ignored.
    index : pandas.Index
        Maps labels to matrix positions via ``get_loc``; its length defines
        the size of the returned matrices.

    Returns
    -------
    tuple
        ``(laplacian, adjacency)``: the Laplacian in CSR format and the
        symmetric 0/1 adjacency matrix it was computed from, both of shape
        ``(len(index), len(index))``.
    """
    # Import the submodules explicitly: attribute access through the bare
    # `scipy` namespace only works if something else has already loaded them.
    from scipy.sparse import csr_matrix
    from scipy.sparse.csgraph import laplacian

    all_edges = set()
    for a, b in edges:
        try:
            a = index.get_loc(a)
            b = index.get_loc(b)
        except KeyError:
            continue
        if a == b:  # exclude self links
            continue
        # make the graph undirected
        all_edges.add((a, b))
        all_edges.add((b, a))

    # Materialise coordinates as arrays (a zip iterator can only be consumed
    # once and cannot represent an empty edge set), and fix the shape so that
    # trailing nodes without edges are not silently dropped.
    n_nodes = len(index)
    pairs = np.array(sorted(all_edges), dtype=np.intp).reshape(-1, 2)
    sp_edges = csr_matrix((np.ones(len(pairs)), (pairs[:, 0], pairs[:, 1])),
                          shape=(n_nodes, n_nodes))
    return laplacian(sp_edges).tocsr(), sp_edges
|
||
|
||
def get_epinions_data(ratings_path=None, trust_data_path=None):
    """Load ratings and/or trust edges from whitespace-separated text files.

    In both files the first line and the last line are skipped.

    Parameters
    ----------
    ratings_path : str, optional
        File whose columns are read as user, film, rating.
    trust_data_path : str, optional
        File with trust edges; only its first two columns are read.

    Returns
    -------
    pandas.DataFrame or list
        A single DataFrame when exactly one path is given, otherwise a list
        (empty when no path is given; ``[ratings, edges]`` when both are).
    """
    # `sep=r'\s+'` is what `delim_whitespace=True` meant; the latter is
    # deprecated in recent pandas releases.
    common = dict(sep=r'\s+',
                  skiprows=[0],
                  skipfooter=1,
                  engine='python',  # skipfooter is only supported by this engine
                  header=None,
                  skipinitialspace=True)
    res = []
    if ratings_path:
        columns = ['user', 'film', 'rating']
        ratings = pd.read_csv(ratings_path, names=columns, usecols=columns, **common)
        res.append(ratings)

    if trust_data_path:
        edges = pd.read_csv(trust_data_path, usecols=[0, 1], **common)
        res.append(edges)

    if len(res) == 1:
        res = res[0]
    return res
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,126 @@ | ||
import numpy as np | ||
from polara.recommender.models import RecommenderModel | ||
|
||
from polara import SVDModel | ||
from polara.recommender.models import RecommenderModel, ScaledSVD | ||
from polara.lib.similarity import stack_features | ||
from polara.lib.sparse import sparse_dot | ||
|
||
class ContentBasedColdStart(RecommenderModel): | ||
|
||
class ItemColdStartEvaluationMixin:
    """Mixin that configures a model for item cold-start evaluation.

    It must precede the model base class in the MRO: initialisation is
    delegated through cooperative ``super()`` and ``self.data`` is expected
    to be available afterwards.
    """

    def __init__(self, *args, **kwargs):
        # Single cooperative call; naming a concrete class in super() would
        # break as soon as the mixin is combined with a different base.
        super().__init__(*args, **kwargs)
        self.filter_seen = False  # there are no seen entities in cold start
        # Key is derived from the item field name, the target is the user field.
        self._prediction_key = '{}_cold'.format(self.data.fields.itemid)
        self._prediction_target = self.data.fields.userid
|
||
|
||
class RandomModelItemColdStart(ItemColdStartEvaluationMixin, RecommenderModel):
    """Cold-start baseline that assigns random users to every cold-start item.

    Parameters
    ----------
    seed : int, optional
        Keyword-only; seeds a dedicated ``RandomState``. When omitted, the
        global ``numpy.random`` module is used as the source of randomness.
    """

    def __init__(self, *args, **kwargs):
        # `seed` is consumed here so that it is not passed on to the base class.
        self.seed = kwargs.pop('seed', None)
        super().__init__(*args, **kwargs)
        self.method = 'RND(cs)'

    def build(self):
        """Initialise the random number source; there is nothing to train."""
        seed = self.seed
        self._random_state = np.random.RandomState(seed) if seed is not None else np.random

    def get_recommendations(self):
        """Return an array with ``topk`` distinct random users per cold-start item.

        Candidates are the representative users when the data model provides
        them, otherwise all users of the training index.
        """
        repr_users = self.data.representative_users
        if repr_users is None:
            repr_users = self.data.index.userid.training
        repr_users = repr_users.new.values
        n_cold_items = self.data.index.itemid.cold_start.shape[0]
        # Read-only broadcast view: one row of candidates per cold item without
        # copying, and without assuming the source array is contiguous (which
        # a hand-written stride of `itemsize` would).
        users_matrix = np.broadcast_to(repr_users, (n_cold_items, len(repr_users)))
        random_users = np.apply_along_axis(self._random_state.choice, 1,
                                           users_matrix, self.topk, replace=False)
        return random_users
|
||
|
||
class PopularityModelItemColdStart(ItemColdStartEvaluationMixin, RecommenderModel):
    """Cold-start baseline that assigns the most active users to every cold-start item."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.method = 'MP(cs)'

    def build(self):
        """Rank users by the number of their records in the training data."""
        userid = self.data.fields.userid
        user_activity = self.data.training[userid].value_counts(sort=False)
        repr_users = self.data.representative_users
        if repr_users is not None:
            # restrict the candidates to the representative users only
            user_activity = user_activity.reindex(repr_users.new.values)
        self.user_scores = user_activity.sort_values(ascending=False)

    def get_recommendations(self):
        """Return the same top users for every cold-start item.

        The result has one row per cold-start item and at most ``topk``
        columns (fewer when fewer users are available).
        """
        n_cold_items = self.data.index.itemid.cold_start.shape[0]
        top_users = self.user_scores.index[:self.topk].values
        # Make a real copy per row: a zero-stride view of fixed width `topk`
        # would read past the end of `top_users` when fewer than `topk` users
        # exist, and writes to it would alias every row.
        return np.tile(top_users, (n_cold_items, 1))
|
||
|
||
class SimilarityAggregationItemColdStart(ItemColdStartEvaluationMixin, RecommenderModel):
    """Cold-start model scoring users by aggregated similarity of cold items to known items."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.method = 'SIM(cs)'
        self.implicit = False  # when True, training values are replaced by ones
        self.dense_output = False  # forwarded to sparse_dot
        
    def build(self):
        """Nothing to train: scores are computed directly from the data model."""
        pass

    def get_recommendations(self):
        """Return the top users for every cold-start item."""
        item_similarity_scores = self.data.cold_items_similarity

        user_item_matrix = self.get_training_matrix()
        # Binarise only on request; doing it unconditionally would make the
        # `implicit` switch meaningless.
        if self.implicit:
            user_item_matrix.data = np.ones_like(user_item_matrix.data)
        scores = sparse_dot(item_similarity_scores, user_item_matrix, self.dense_output, True)
        top_similar_users = self.get_topk_elements(scores).astype(np.intp)
        return top_similar_users
|
||
class SVDModelItemColdStart(ItemColdStartEvaluationMixin, SVDModel):
    """SVD-based item cold-start model.

    Latent factors of cold-start items are derived from their side
    information: either raw item features passed to the constructor, or the
    similarity data exposed by the data model.
    """

    def __init__(self, *args, item_features=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.method = 'PureSVD(cs)'
        self.item_features = item_features
        # raw features take precedence whenever they are supplied
        self.use_raw_features = item_features is not None

    def build(self, *args, **kwargs):
        """Build the underlying SVD model, always requesting the factors."""
        super().build(*args, return_factors=True, **kwargs)

    def get_recommendations(self):
        """Return the top users for every cold-start item."""
        fields = self.data.fields
        user_factors = self.factors[fields.userid]
        item_factors = self.factors[fields.itemid]
        singular_values = self.factors['singular_values']

        if self.use_raw_features:
            known_items = self.data.index.itemid.training.old.values
            known_info = self.item_features.reindex(known_items, fill_value=[])
            known_features, feature_labels = stack_features(known_info, normalize=False)
            w = known_features.T.dot(item_factors).T
            wwt_inv = np.linalg.pinv(w @ w.T)

            cold_items = self.data.index.itemid.cold_start.old.values
            cold_info = self.item_features.reindex(cold_items, fill_value=[])
            # reuse the training labels so that cold features share the same columns
            cold_item_features, _ = stack_features(cold_info,
                                                   labels=feature_labels,
                                                   normalize=False)
        else:
            w = self.data.item_relations.T.dot(item_factors).T
            wwt_inv = np.linalg.pinv(w @ w.T)
            cold_item_features = self.data.cold_items_similarity

        # map side information of cold items into the latent space
        cold_items_factors = cold_item_features.dot(w.T) @ wwt_inv
        scaled_user_factors = user_factors * singular_values[None, :]
        scores = cold_items_factors @ scaled_user_factors.T
        return self.get_topk_elements(scores).astype(np.intp)
|
||
|
||
class ScaledSVDItemColdStart(ScaledSVD, SVDModelItemColdStart):
    """Item cold-start model combining ``ScaledSVD`` with ``SVDModelItemColdStart``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # set after the parent initialisers so that this label is the final one
        self.method = 'PureSVDs(cs)'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from polara.recommender.data import RecommenderData | ||
|
||
|
||
class ItemPostFilteringData(RecommenderData):
    """Data model that tracks item context information for post-filtering.

    Parameters
    ----------
    item_context_mapping : mapping, optional
        Keyword-only. Maps a context name to a table that contains both the
        item field column and a column named after the context. ``None`` is
        treated as "no contexts".
    """

    def __init__(self, *args, item_context_mapping=None, **kwargs):
        super().__init__(*args, **kwargs)
        userid = self.fields.userid
        itemid = self.fields.itemid
        # Copy the mapping so later changes by the caller do not leak in;
        # the default None must not crash the constructor.
        if item_context_mapping is None:
            item_context_mapping = {}
        self.item_context_mapping = dict(item_context_mapping)
        # per context: {userid: <user -> context>, itemid: <context -> items>}
        self.context_data = {context: dict.fromkeys([userid, itemid])
                             for context in self.item_context_mapping}

    def map_context_data(self, context):
        """Compute user and item lookups for ``context`` from the current holdout.

        Does nothing for ``context=None``. Prints a message and leaves the
        stored data untouched when the holdout cannot be used.
        """
        if context is None:
            return

        userid = self.fields.userid
        itemid = self.fields.itemid

        context_mapping = self.item_context_mapping[context]
        index_mapping = self.index.itemid.set_index('old').new
        # translate original item ids into internal index values
        mapped_index = {itemid: lambda x: x[itemid].map(index_mapping)}
        item_data = (context_mapping.loc[lambda x: x[itemid].isin(index_mapping.index)]
                                    .assign(**mapped_index)
                                    .groupby(context)[itemid]
                                    .apply(list))
        holdout = self.test.holdout
        try:
            user_data = holdout.set_index(userid)[context]
        except AttributeError:
            print(f'Unable to map {context}: holdout data is not recognized')
            return
        except KeyError:
            print(f'Unable to map {context}: not present in holdout')
            return
        # deal with mismatch between user and item data
        item_data = item_data.reindex(user_data.drop_duplicates().values, fill_value=[])

        self.context_data[context][userid] = user_data
        self.context_data[context][itemid] = item_data

    def update_contextual_data(self):
        """Recompute the lookups of every registered context."""
        holdout = self.test.holdout
        if holdout is not None:
            # assuming that for each user in holdout we have only 1 item
            assert holdout.shape[0] == holdout[self.fields.userid].nunique()

        for context in self.item_context_mapping.keys():
            self.map_context_data(context)

    def prepare(self, *args, **kwargs):
        """Prepare the data as the base class does, then refresh context lookups."""
        super().prepare(*args, **kwargs)
        self.update_contextual_data()

    def set_test_data(self, *args, **kwargs):
        """Set test data as the base class does, then refresh context lookups."""
        super().set_test_data(*args, **kwargs)
        self.update_contextual_data()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import numpy as np | ||
|
||
|
||
class ItemPostFilteringMixin:
    """Model mixin that boosts scores of items matching the users' context.

    Expects ``self.data.context_data`` to map a context name to a dict that
    holds a user -> context lookup under the user field name and a
    context -> list of item positions lookup under the item field name.
    """

    def upvote_context_items(self, context, scores, test_users):
        """Raise, in place, the scores of items that share ``context`` with each user.

        Parameters
        ----------
        context : hashable or None
            Context name; ``None`` is a no-op.
        scores : numpy.ndarray
            Dense score matrix with one row per entry of ``test_users``;
            modified in place.
        test_users : sequence
            User labels used to look up the context of each row.
        """
        if context is None:
            return

        userid = self.data.fields.userid
        itemid = self.data.fields.itemid
        context_data = self.data.context_data[context]
        try:
            upvote_items = context_data[userid].loc[test_users].map(context_data[itemid])
        except Exception:
            # Best effort: the context data may be missing or incomplete.
            # Unlike a bare `except`, this lets KeyboardInterrupt/SystemExit through.
            print(f'Unable to upvote items in context "{context}"')
            return

        rows, cols = [], []
        for row, items in enumerate(upvote_items):
            for item in items:
                rows.append(row)
                cols.append(item)
        if not rows:
            # nothing to upvote; ravel_multi_index would fail on empty input
            return

        context_idx_flat = np.ravel_multi_index((rows, cols), scores.shape)
        context_scores = scores.flat[context_idx_flat]

        # Shift the selected items above the current maximum while keeping
        # their relative order.
        upscored = scores.max() + context_scores + 1
        scores.flat[context_idx_flat] = upscored

    def upvote_relevant_items(self, scores, test_users):
        """Apply ``upvote_context_items`` for every context known to the data model."""
        for context in self.data.context_data.keys():
            self.upvote_context_items(context, scores, test_users)

    def slice_recommendations(self, test_data, test_shape, start, stop, test_users):
        """Compute scores with the parent implementation, then upvote context items."""
        scores, slice_data = super().slice_recommendations(test_data, test_shape, start, stop, test_users)
        self.upvote_relevant_items(scores, test_users[start:stop])
        return scores, slice_data
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import numpy as np | ||
import pandas as pd | ||
from scipy.sparse import issparse | ||
|
||
from polara.recommender.data import RecommenderData | ||
|
||
|
||
class SideRelationsMixin:
    """Data-model mixin exposing user/item relation matrices aligned with the current index.

    Parameters
    ----------
    rel_mat : dict
        Maps an entity field name (user or item) to a square relations
        matrix, or to ``None``. Entities may be omitted altogether.
    rel_idx : dict
        Maps an entity field name to the labels of the rows/columns of the
        corresponding matrix, or to ``None``.

    The remaining arguments are forwarded to the next class in the MRO.
    """

    def __init__(self, rel_mat, rel_idx, *args, **kwargs):
        super().__init__(*args, **kwargs)

        entities = [self.fields.userid, self.fields.itemid]
        # label -> position within the supplied relations matrix
        self._rel_idx = {entity: pd.Series(index=idx, data=np.arange(len(idx)), copy=False)
                                 if idx is not None else None
                         for entity, idx in rel_idx.items()
                         if entity in entities}
        self._rel_mat = {entity: mat for entity, mat in rel_mat.items() if entity in entities}
        # lazily computed matrices, reset on every data change event
        self._relations = dict.fromkeys(entities)

        self.subscribe(self.on_change_event, self._clean_relations)

    def _clean_relations(self):
        """Drop computed matrices so that they are rebuilt on next access."""
        self._relations = dict.fromkeys(self._relations.keys())

    @property
    def item_relations(self):
        """Relations matrix for items, or ``None`` when not provided."""
        entity = self.fields.itemid
        return self.get_relations_matrix(entity)

    @property
    def user_relations(self):
        """Relations matrix for users, or ``None`` when not provided."""
        entity = self.fields.userid
        return self.get_relations_matrix(entity)

    def get_relations_matrix(self, entity):
        """Return the relations matrix for ``entity``, computing it on first use."""
        relations = self._relations.get(entity, None)
        if relations is None:
            self._update_relations(entity)
        return self._relations[entity]

    def _update_relations(self, entity):
        """Reorder the supplied matrix of ``entity`` to follow the current entity index."""
        # An entity without relations may be given as None or simply omitted.
        rel_mat = self._rel_mat.get(entity)
        if rel_mat is None:
            self._relations[entity] = None
            return

        labels = self._rel_idx.get(entity)
        if labels is None:
            raise ValueError(f'No index provided for {entity} relations matrix')

        if self.verbose:
            print(f'Updating {entity} relations matrix')

        index_data = self.get_entity_index(entity)
        entity_idx = index_data['old']

        # positions of the current entities within the supplied matrix
        rel_idx = entity_idx.map(labels).values
        self._relations[entity] = rel_mat[:, rel_idx][rel_idx, :]
|
||
|
||
class IdentityDiagonalMixin:
    """Mixin that sets the diagonal of every available relations matrix to one."""

    def _update_relations(self, *args, **kwargs):
        """Update relations via the parent class, then overwrite diagonals in place."""
        super()._update_relations(*args, **kwargs)
        available = (mat for mat in self._relations.values() if mat is not None)
        for mat in available:
            if not issparse(mat):
                np.fill_diagonal(mat, 1)
                continue
            mat.setdiag(1)
|
||
|
||
class SimilarityDataModel(IdentityDiagonalMixin, SideRelationsMixin, RecommenderData):
    """Recommender data model with side relations whose diagonals are set to one."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import scipy as sp | ||
import numpy as np | ||
|
||
from polara.recommender.models import RecommenderModel, ProbabilisticMF | ||
from polara.lib.optimize import kernelized_pmf_sgd | ||
from polara.lib.sparse import sparse_dot | ||
from polara.tools.timing import track_time | ||
|
||
|
||
class SimilarityAggregation(RecommenderModel):
    """Model that scores items by aggregating item similarities over the test data."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.method = 'SIM'
        self.implicit = False  # when True, test values are replaced by ones
        self.dense_output = False  # forwarded to sparse_dot
        self.item_similarity_matrix = False  # placeholder until build() is called

    def build(self):
        """Take a private copy of the item relations with the diagonal removed."""
        # work on a copy so the matrix owned by the data model stays intact
        self.item_similarity_matrix = similarity = self.data.item_relations.copy()
        similarity.setdiag(0)  # exclude self-links
        similarity.eliminate_zeros()

    def slice_recommendations(self, test_data, shape, start, stop, test_users=None):
        """Return ``(scores, slice_data)`` for test rows ``start`` to ``stop``."""
        slice_idx = (start, stop)
        test_matrix, slice_data = self.get_test_matrix(test_data, shape, slice_idx)
        if self.implicit:
            test_matrix.data = np.ones_like(test_matrix.data)
        scores = sparse_dot(test_matrix, self.item_similarity_matrix,
                            self.dense_output, True)
        return scores, slice_data
|
||
|
||
class KernelizedRecommenderMixin:
    '''Based on the work:
    Kernelized Probabilistic Matrix Factorization: Exploiting Graphs and Side Information
    http://people.ee.duke.edu/~lcarin/kpmf_sdm_final.pdf

    Kernel matrices are computed lazily per entity from the relations
    matrices of the data model and are dropped on every data change event.
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.kernel_type = 'reg'  # 'reg' (regularized laplacian) or 'dif' (diffusion)
        self.beta = 0.01  # scaling used by the 'dif' kernel
        self.gamma = 0.1  # scaling used by the 'reg' kernel
        self.sigma = 1
        self.kernel_update = None # will use default kernel update method
        self.sparse_kernel_format = True

        entities = [self.data.fields.userid, self.data.fields.itemid]
        # per-entity scale, used when no relations matrix is available
        self.factor_sigma = dict.fromkeys(entities, 1)
        self._kernel_matrices = dict.fromkeys(entities)

        self.data.subscribe(self.data.on_change_event, self._clean_kernel_data)

    def _compute_kernel(self, laplacian, kernel_type=None):
        """Compute a kernel matrix from ``laplacian``.

        Parameters
        ----------
        laplacian : sparse matrix
            Square matrix the kernel is derived from.
        kernel_type : {'dif', 'reg'}, optional
            Overrides ``self.kernel_type`` when given.

        Raises
        ------
        ValueError
            If the kernel type is not recognised.
        """
        # Explicit submodule imports: `scipy.sparse.linalg` is not guaranteed
        # to be reachable through the bare `scipy` namespace.
        from scipy.sparse import eye
        from scipy.sparse.linalg import expm

        kernel_type = kernel_type or self.kernel_type
        if kernel_type == 'dif': # diffusion
            return expm(self.beta * laplacian)
        elif kernel_type == 'reg': # regularized laplacian
            n_entities = laplacian.shape[0]
            return eye(n_entities).tocsr() + self.gamma * laplacian # sparse matrix
        else:
            raise ValueError(f'Unknown kernel type: {kernel_type!r}; expected "dif" or "reg"')

    def _update_kernel_matrices(self, entity):
        """Compute and store the kernel matrix for ``entity``."""
        from scipy.sparse import eye

        laplacian = self.data.get_relations_matrix(entity)
        if laplacian is None:
            # no side information: fall back to a scaled identity matrix
            sigma = self.factor_sigma[entity]
            n_entities = self.data.get_entity_index(entity).shape[0]
            kernel_matrix = (sigma**2) * eye(n_entities).tocsr()
        else:
            kernel_matrix = self._compute_kernel(laplacian)
        self._kernel_matrices[entity] = kernel_matrix

    def _clean_kernel_data(self):
        """Drop stored kernels so that they are recomputed on next access."""
        self._kernel_matrices = dict.fromkeys(self._kernel_matrices.keys())

    @property
    def item_kernel_matrix(self):
        """Kernel matrix for items."""
        entity = self.data.fields.itemid
        return self.get_kernel_matrix(entity)

    @property
    def user_kernel_matrix(self):
        """Kernel matrix for users."""
        entity = self.data.fields.userid
        return self.get_kernel_matrix(entity)

    def get_kernel_matrix(self, entity):
        """Return the kernel matrix for ``entity``, computing it on first use."""
        kernel_matrix = self._kernel_matrices.get(entity, None)
        if kernel_matrix is None:
            self._update_kernel_matrices(entity)
        return self._kernel_matrices[entity]
|
||
|
||
class KernelizedPMF(KernelizedRecommenderMixin, ProbabilisticMF):
    """Probabilistic matrix factorization trained with ``kernelized_pmf_sgd``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.optimizer = kernelized_pmf_sgd
        self.method = 'KPMF'

    def build(self, *args, **kwargs):
        """Build the model, passing the user and item kernels to the parent ``build``."""
        user_kernel = self.user_kernel_matrix
        item_kernel = self.item_kernel_matrix
        super().build((user_kernel, item_kernel), *args,
                      kernel_update=self.kernel_update,
                      sparse_kernel_format=self.sparse_kernel_format,
                      **kwargs)
Oops, something went wrong.