-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmarginal_mcsat.py
143 lines (107 loc) · 4.38 KB
/
marginal_mcsat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#script(python)
import math
import random
import time

import clingo
import sympy

import xor_constraint_drawer
class mcSAT(object):
    """MC-SAT-style marginal sampler over an ASP program, driven by clingo.

    ``runMCASP`` grounds and solves the program once to obtain an initial
    sample. It then repeatedly:

    (a) selects some of the previous sample's ``unsat*`` atoms via
        ``findUnsatRules``;
    (b) turns them into integrity constraints;
    (c) asks ``xor_constraint_drawer.xorSampler`` for the next sample.

    For every ground atom whose predicate name is in ``queryList`` it counts
    how many samples contain that atom. ``printQuery`` reports
    count / number of samples.
    """

    def __init__(self, content, evidence, queryList, xorMode, max_liter=500,
                 seed=None):
        """Configure the sampler.

        :param content: ASP program text, added to the ``base`` part.
        :param evidence: ASP evidence text; ``""`` means no evidence.
        :param queryList: predicate names whose ground atoms are counted.
        :param xorMode: value forwarded unchanged to ``xorSampler``.
        :param max_liter: number of samples ``runMCASP`` draws (the initial
            clingo sample plus ``max_liter - 1`` XOR draws; at least one).
        :param seed: seed for Python's ``random`` module. ``None`` (the
            default) seeds from system entropy, as before.
        """
        warn_option = "--warn=none"
        # NOTE(review): "-t 4" is passed as a single argv element; confirm
        # clingo parses it as intended (vs. the two elements "-t", "4").
        thread_option = "-t 4"
        self.clingoOptions = [warn_option, thread_option]
        self.max_num_iteration = max_liter
        self.seed = seed
        # (atom, present) pairs describing the most recent sample.
        self.curr_sample = None
        # Last model stored by getSample.
        self.whole_model = []
        # Query atom -> number of samples containing it.
        self.query_count = {}
        # Ground query atoms, in the order clingo lists them.
        self.domain = []
        self.aspContent = content
        self.eviContent = evidence
        self.queryList = queryList
        # Every processed sample, in draw order.
        self.sampleForReturn = []
        self.xorM = xorMode

    def getSample(self, model):
        """Store a list copy of ``model``'s atoms in ``self.whole_model``.

        :param model: iterable of atoms.
        :returns: the stored list.
        """
        self.whole_model = list(model)
        return self.whole_model

    def findUnsatRules(self, atoms):
        """Randomly select ``unsat*`` atoms to be kept as constraints.

        Each atom whose name starts with ``unsat`` carries a weight as its
        second argument (a number, possibly quoted). The atom is selected
        with probability ``max(0, 1 - exp(weight))``, so only negative
        weights can ever be selected. Presumably the translation stores
        penalties as negative weights; confirm against the encoder.

        One ``random.random()`` draw is consumed per ``unsat*`` atom.

        :param atoms: iterable of clingo symbols (objects with ``name`` and
            ``arguments``).
        :returns: the selected atoms, in input order.
        """
        selected = []
        for atom in atoms:
            if not atom.name.startswith('unsat'):
                continue
            weight = float(str(atom.arguments[1]).replace("\"", ""))
            r = random.random()
            # For weight >= 0 the threshold 1 - exp(weight) is <= 0, and r
            # (in [0, 1)) can never fall below it. Clamping the weight at 0
            # keeps that outcome and avoids OverflowError in math.exp for
            # large positive weights.
            if r < 1.0 - math.exp(min(weight, 0.0)):
                selected.append(atom)
        return selected

    def processSample(self, atoms):
        """Record one sample and update the query counts.

        The sample is appended to ``self.sampleForReturn``. Each domain atom
        present in ``atoms`` has its count in ``self.query_count``
        incremented.

        :param atoms: the atoms of one model (clingo symbols).
        :returns: ``(M, sample_attempt)``. ``M`` is
            ``findUnsatRules(atoms)``. ``sample_attempt`` lists
            ``(atom, present)`` for every atom in ``self.domain``.
        """
        M = self.findUnsatRules(atoms)
        self.sampleForReturn.append(atoms)
        # Build a set once: O(1) membership per domain atom instead of a
        # scan over the whole model for each one.
        present = set(atoms)
        sample_attempt = []
        for r in self.domain:
            if r in present:
                sample_attempt.append((r, True))
                if r in self.query_count:
                    self.query_count[r] += 1
            else:
                sample_attempt.append((r, False))
        return M, sample_attempt

    def runMCASP(self):
        """Draw samples and accumulate query counts.

        Draws one initial sample with clingo, then
        ``self.max_num_iteration - 1`` further samples through the XOR
        sampler (at least one sample in total). Per-run state (``domain``,
        ``query_count``, ``sampleForReturn``, ``curr_sample``) is reset
        first, so repeated calls neither duplicate domain atoms nor
        double-count queries.

        :returns: ``True`` on success. ``False`` if the program has no model
            or an XOR-constrained draw returns no model.
        """
        self.domain = []
        self.query_count = {}
        self.sampleForReturn = []
        self.curr_sample = None

        model = self._drawFirstSample()
        if model is None:
            print("Program has no satisfiable solution, exit!")
            return False
        M, self.curr_sample = self.processSample(model)

        for sample_count in range(self.max_num_iteration - 1):
            if sample_count % 10 == 0:
                print("Getting sample ", sample_count)
            models = self._drawXorSamples(self._constraintsFor(M))
            if not models:
                # Without a model there is no state to continue from;
                # report it instead of failing with an IndexError.
                print("XOR sampler returned no model, exit!")
                return False
            if len(models) > 1:
                model = models[random.randint(0, len(models) - 1)]
            else:
                # A single model is taken without consuming a random draw.
                model = models[0]
            M, self.curr_sample = self.processSample(model)
        return True

    def _drawFirstSample(self):
        """Ground the program, register query atoms, return one model.

        Populates ``self.domain`` and ``self.query_count`` with every ground
        atom whose predicate name is in ``self.queryList``. Seeds ``random``
        with ``self.seed`` before solving.

        :returns: a randomly chosen model's atom list from the clingo solve,
            or ``None`` if the solve reported no model.
        """
        control = clingo.Control(self.clingoOptions)
        control.add("base", [], self.aspContent)
        control.ground([("base", [])])
        if self.eviContent != "":
            control.add("evid", [], self.eviContent)
            control.ground([("evid", [])])
        for atom in control.symbolic_atoms:
            if atom.symbol.name in self.queryList:
                self.query_count[atom.symbol] = 0
                self.domain.append(atom.symbol)
        random.seed(self.seed)
        models = []
        control.solve([], lambda model: models.append(model.symbols(atoms=True)))
        if not models:
            return None
        return models[random.randint(0, len(models) - 1)]

    @staticmethod
    def _constraintsFor(M):
        """Return ``:- not a(args).`` lines forcing every atom in ``M`` to hold."""
        lines = []
        for m in M:
            argsStr = ','.join(str(arg) for arg in m.arguments)
            lines.append(':- not ' + m.name + '(' + argsStr + ').\n')
        return ''.join(lines)

    def _drawXorSamples(self, constraintContent):
        """Run the XOR sampler on program, optional evidence and constraints.

        :returns: whatever ``xorSampler.startDrawSample()`` returns
            (indexed as a list of models by the caller).
        """
        programs = [self.aspContent]
        if self.eviContent != "":
            programs.append(self.eviContent)
        programs.append(constraintContent)
        xorSampler = xor_constraint_drawer.xorSampler(
            self.xorM, programs, self.clingoOptions)
        return xorSampler.startDrawSample()

    def printQuery(self):
        """Print each query atom's estimated marginal probability.

        For each atom, three lines are printed: ``atom :  count/samples``,
        then the raw count, then the number of recorded samples used as the
        denominator.
        """
        num_samples = len(self.sampleForReturn)
        if num_samples == 0:
            print("No samples have been drawn.")
            return
        for atom in self.query_count:
            count = self.query_count[atom]
            print(atom, ": ", float(count) / float(num_samples))
            print(count)
            print(num_samples)

    def getSamples(self):
        """Return the list of all samples processed so far, in draw order."""
        return self.sampleForReturn