# twenty_newsgroups_autoencoder.py
#
# Autoencoder experiment on the vectorized 20 Newsgroups corpus: features are
# selected with an L1-penalised LinearSVC, then a linear encoder and a
# FirstLayerSparseDecoder are trained with SGD on the reconstruction loss.
from collections import Counter
import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.data import TensorDataset
from sklearn.datasets import fetch_20newsgroups, fetch_20newsgroups_vectorized
from sklearn.svm import LinearSVC
from sklearn.feature_selection import SelectFromModel
# from sklearn.feature_extraction.text import TfidfVectorizer
from lib.common import FirstLayerSparseDecoder
print('Loading 20 newsgroups...')
# Pre-vectorized corpus (all splits) with headers, footers and quotes removed.
newsgroups = fetch_20newsgroups_vectorized(
    subset='all',
    remove=('headers', 'footers', 'quotes'),
)

# Reorder the documents by label so that every newsgroup forms one
# contiguous run, then count how many documents each label has.
ix = np.argsort(newsgroups.target)
Xfat, y = newsgroups.data[ix], newsgroups.target[ix]
group_counter = Counter(y)
group_counts = [group_counter[label] for label in range(20)]

print('Fitting SVM for TF-IDF feature selection...')
# The L1 penalty drives coefficients to zero; lowering C strengthens that
# sparsity pressure.
lsvc = LinearSVC(penalty='l1', loss='squared_hinge', dual=False, C=0.01)
lsvc.fit(Xfat, y)
selector = SelectFromModel(lsvc, threshold=1e-3, prefit=True)
X_sel = selector.transform(Xfat)

# The selected matrix is transposed before being densified, so rows of X index
# the selected features and columns index documents (assuming scikit-learn's
# usual samples-by-features layout -- TODO confirm).
dense_features = X_sel.astype(np.float32).T.toarray()
X = torch.from_numpy(dense_features)
print(X.size())

# A single batch holds every row of X, i.e. full-batch training.
dataset = TensorDataset(X, torch.arange(X.size(0)))
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=X.size(0),
    shuffle=True,
)

# Seed before any model parameters are created.
torch.manual_seed(0)

dim_z = 64
num_epochs = 100000
lam = 0.0

encoder = torch.nn.Linear(X.size(1), dim_z)
def make_linear_decoder():
    """Build a FirstLayerSparseDecoder with one bias-free linear map per group.

    One ``torch.nn.Linear(dim_z, n, bias=False)`` is created for each ``n`` in
    the module-level ``group_counts``. The list of modules is handed to
    ``FirstLayerSparseDecoder`` together with a list of 20 copies of the
    per-group input size and ``dim_z``.
    """
    group_input_dim = dim_z
    generators = [
        torch.nn.Linear(group_input_dim, width, bias=False)
        for width in group_counts
    ]
    per_group_input_dims = [group_input_dim] * 20
    return FirstLayerSparseDecoder(generators, per_group_input_dims, dim_z)
def make_nonlinear_decoder():
    """Build a FirstLayerSparseDecoder with one small tanh MLP per group.

    For each ``n`` in the module-level ``group_counts`` a three-layer network
    ``Linear(1, n) -> Tanh -> Linear(n, n) -> Tanh -> Linear(n, n)`` is
    created. The list of networks is handed to ``FirstLayerSparseDecoder``
    together with a list of 20 copies of the per-group input size (1) and
    ``dim_z``.
    """
    group_input_dim = 1

    def build_generator(width):
        # Layers are instantiated in input-to-output order.
        return torch.nn.Sequential(
            torch.nn.Linear(group_input_dim, width),
            torch.nn.Tanh(),
            torch.nn.Linear(width, width),
            torch.nn.Tanh(),
            torch.nn.Linear(width, width),
        )

    generators = [build_generator(width) for width in group_counts]
    per_group_input_dims = [group_input_dim] * 20
    return FirstLayerSparseDecoder(generators, per_group_input_dims, dim_z)
# Decoder under test; swap the call for the alternative below to use the
# per-group tanh MLPs instead of the linear maps.
decoder = make_linear_decoder()
# decoder = make_nonlinear_decoder()

# Alternative optimizer that was tried: Adam with the same learning rate for
# both parameter groups.
# lr = 1e-4
# optimizer = torch.optim.Adam([
#     {'params': encoder.parameters(), 'lr': lr},
#     {'params': decoder.parameters(), 'lr': lr}
# ])

lr = 1e-4
momentum = 0.9

# One parameter group per model, both using the same SGD settings.
param_groups = [
    {'params': module.parameters(), 'lr': lr, 'momentum': momentum}
    for module in (encoder, decoder)
]
optimizer = torch.optim.SGD(param_groups)
# Calculate the reconstruction loss on the given data
def reconstruction_loss(data, encoder_net=None, decoder_net=None):
    """Return the squared reconstruction error of ``data``, averaged over rows.

    Args:
        data: 2-D float tensor. Every row is encoded, decoded and compared
            against itself.
        encoder_net: Optional callable used in place of the module-level
            ``encoder``.
        decoder_net: Optional callable used in place of the module-level
            ``decoder``.

    Returns:
        A 0-dim tensor holding the sum of squared residuals divided by the
        number of rows in ``data``. It stays attached to the autograd graph
        so callers can call ``backward()`` on it.

    Side effects:
        Prints 'solution has collapsed!' when the summed absolute difference
        between every reconstructed row and the first reconstructed row,
        divided by the number of rows, is at most 1e-3.
    """
    # Only fall back to the module-level models when no override is given.
    enc = encoder if encoder_net is None else encoder_net
    dec = decoder if decoder_net is None else decoder_net

    # On PyTorch >= 0.4 tensors carry autograd state themselves, so the
    # deprecated Variable wrapper is not needed.
    reconstructed = dec(enc(data))
    num_rows = data.size(0)

    # Diagnostic only, so keep it out of the autograd graph. `.item()` is
    # required here: indexing a 0-dim tensor with `[0]` raises on
    # PyTorch >= 0.4.
    with torch.no_grad():
        spread = torch.sum(torch.abs(reconstructed - reconstructed[0])).item()
    if spread / num_rows <= 1e-3:
        print('solution has collapsed!')

    residual = reconstructed - data
    return torch.sum(torch.pow(residual, 2)) / num_rows
# Training loop: gradient step on the reconstruction loss followed by a
# proximal step on the decoder.
for epoch in range(num_epochs):
    for batch_idx, (data, _) in enumerate(data_loader):
        train_loss = reconstruction_loss(data)
        # The penalty is only computed for logging; it is never
        # backpropagated. Sparsity is instead applied through
        # `decoder.proximal_step` below (presumably the proximal operator of
        # the group-lasso term -- confirm against lib.common).
        sparsity_penalty = lam * decoder.group_lasso_penalty()

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        decoder.proximal_step(lr * lam)

        print('epoch: {} [{}/{} ({:.0f}%)]'.format(
            epoch,
            batch_idx * len(data),
            len(data_loader.dataset),
            100. * batch_idx / len(data_loader)
        ))
        # `.item()` extracts the Python scalar; `.data[0]` raises IndexError
        # on 0-dim tensors with PyTorch >= 0.4.
        print(' reconstruction loss:', train_loss.item())
        print(' regularization: ', sparsity_penalty.item())
        print(' combined loss: ', (train_loss + sparsity_penalty).item())