-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsimpleTest_hessian.py
37 lines (24 loc) · 979 Bytes
/
simpleTest_hessian.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import numpy as np

from rbmLib.rbm import *
from rbmLib.training import *
def buildDataset(patterns, nDoublings):
    """Replicate a pattern matrix ``2**nDoublings`` times along axis 0.

    This is equivalent to stacking the array on top of itself
    ``nDoublings`` times in a row, so the relative frequency of every row
    of *patterns* is preserved in the result.

    Args:
        patterns: 2-D array whose rows are training examples.
        nDoublings: Number of doublings to apply; must be >= 0.

    Returns:
        A new array with ``len(patterns) * 2**nDoublings`` rows and the same
        dtype and column count as *patterns*.

    Raises:
        ValueError: If *nDoublings* is negative.
    """
    if nDoublings < 0:
        raise ValueError("nDoublings must be non-negative")
    return np.tile(patterns, (2 ** nDoublings, 1))


def empiricalDistribution(visibles):
    """Return the distinct visible states and the fraction of chains in each.

    Args:
        visibles: 2-D array that is transposed so each column becomes one
            row (one sampled state). This presumably matches the
            (nVisible, nChains) layout returned by ``gibbsSampler.sample``;
            TODO confirm against rbmLib.

    Returns:
        ``(states, frequencies)``: the unique rows, sorted as ``np.unique``
        sorts them, and their relative frequencies. The frequencies sum to 1
        for non-empty input.
    """
    states, counts = np.unique(np.asarray(visibles).T, axis=0, return_counts=True)
    # Normalise by the number of samples actually counted rather than a
    # separately tracked chain count, so the two can never drift apart.
    return states, counts / counts.sum()


if __name__ == "__main__":
    rbm = RBM(4, 2)
    # To warm-start from a previous run, call rbm.load("simpleHessian.rbm")
    # here. Catch only the specific error rbmLib raises for a missing or
    # invalid file; never use a bare `except:`.

    # Five base patterns. [1, 1, 1, 1] is listed twice, so the target
    # distribution is 40% for it and 20% for each of the other three.
    basePatterns = np.array([[0, 0, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1],
                             [1, 1, 1, 1], [1, 1, 1, 1]])
    # 11 doublings give 5 * 2**11 = 10240 rows with the same proportions.
    testData = buildDataset(basePatterns, 11)
    np.random.shuffle(testData)
    print(len(testData))

    trainer = RBMTrainerPCDHessian()
    trainer.train(rbm, testData, learningRate=1, learningRateDecay=-1e-1,
                  nMarkovChains=400, nMarkovIter=5, epochs=1000,
                  miniBatchSize=None, convergenceThreshold=1e-3,
                  autosave=False)
    rbm.save("simpleHessian.rbm")
    print(rbm)

    # Sample the trained model. Each printed frequency can be compared with
    # the target distribution above.
    sampler = gibbsSampler(rbm)
    nMarkovTest = 10000
    visibles, _ = sampler.sample(nMarkovChains=nMarkovTest, nMarkovIter=1000)
    for state, frequency in zip(*empiricalDistribution(visibles)):
        print(state, frequency)