Skip to content

Commit

Permalink
Python3 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
rdlester committed Jul 24, 2015
1 parent 63435d5 commit 2c04361
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 109 deletions.
2 changes: 1 addition & 1 deletion README.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Python implementation of Sum-product (aka Belief-Propagation) for discrete Facto

See [this paper](http://www.comm.utoronto.ca/frank/papers/KFL01.pdf) for more details on the Factor Graph framework and the sum-product algorithm. This code was originally written as part of a grad student seminar taught by Erik Sudderth at Brown University; the [seminar web page](http://cs.brown.edu/courses/csci2420/) is an excellent resource for learning more about graphical models.

Requires NumPy.
Requires NumPy and future.

To use:

Expand Down
83 changes: 43 additions & 40 deletions graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Graph class
from __future__ import print_function
from builtins import range
from future.utils import iteritems
import numpy as np
from node import FacNode, VarNode
import pdb
Expand All @@ -8,51 +11,51 @@
Central difference: nbrs stored as references, not ids
(makes message propagation easier)
"""

class Graph:
""" Putting everything together
"""

def __init__(self):
self.var = {}
self.fac = []
self.dims = []
self.converged = False

def addVarNode(self, name, dim):
newId = len(self.var)
newVar = VarNode(name, dim, newId)
self.var[name] = newVar
self.dims.append(dim)

return newVar

def addFacNode(self, P, *args):
newId = len(self.fac)
newFac = FacNode(P, newId, *args)
self.fac.append(newFac)

return newFac

def disableAll(self):
""" Disable all nodes in graph
Useful for switching on small subnetworks
of bayesian nets
"""
for k, v in self.var.iteritems():
for k, v in iteritems(self.var):
v.disable()
for f in self.fac:
f.disable()

def reset(self):
""" Reset messages to original state
"""
for k, v in self.var.iteritems():
for k, v in iteritems(self.var):
v.reset()
for f in self.fac:
f.reset()
self.converged = False

def sumProduct(self, maxsteps=500):
""" This is the algorithm!
Each timestep:
Expand All @@ -64,62 +67,62 @@ def sumProduct(self, maxsteps=500):
timestep = 0
while timestep < maxsteps and not self.converged: # run for maxsteps cycles
timestep = timestep + 1
print timestep
print(timestep)

for f in self.fac:
# start with factor-to-variable
# can send immediately since not sending to any other factors
f.prepMessages()
f.sendMessages()
for k, v in self.var.iteritems():

for k, v in iteritems(self.var):
# variable-to-factor
v.prepMessages()
v.sendMessages()

# check for convergence
t = True
for k, v in self.var.iteritems():
for k, v in iteritems(self.var):
t = t and v.checkConvergence()
if not t:
break
if t:
if t:
for f in self.fac:
t = t and f.checkConvergence()
if not t:
break

if t: # we have convergence!
self.converged = True

# if run for 500 steps and still no convergence:impor
if not self.converged:
print "No convergence!"
print("No convergence!")

def marginals(self, maxsteps=500):
""" Return dictionary of all marginal distributions
indexed by corresponding variable name
"""
# Message pass
self.sumProduct(maxsteps)

marginals = {}
# for each var
for k, v in self.var.iteritems():
for k, v in iteritems(self.var):
if v.enabled: # only include enabled variables
# multiply together messages
vmarg = 1
for i in xrange(0, len(v.incoming)):
for i in range(0, len(v.incoming)):
vmarg = vmarg * v.incoming[i]

# normalize
n = np.sum(vmarg)
vmarg = vmarg / n

marginals[k] = vmarg

return marginals

def bruteForce(self):
""" Brute force method. Only here for completeness.
Don't use unless you want your code to take forever to produce results.
Expand All @@ -132,7 +135,7 @@ def bruteForce(self):
enabledNids = []
enabledNames = []
enabledObserved = []
for k, v in self.var.iteritems():
for k, v in iteritems(self.var):
if v.enabled:
enabledNids.append(v.nid)
enabledNames.append(k)
Expand All @@ -141,17 +144,17 @@ def bruteForce(self):
enabledDims.append(v.dim)
else:
enabledDims.append(1)

# initialize matrix over all joint configurations
joint = np.zeros(enabledDims)

# loop over all configurations
self.configurationLoop(joint, enabledNids, enabledObserved, [])

# normalize
joint = joint / np.sum(joint)
return {'joint': joint, 'names': enabledNames}

def configurationLoop(self, joint, enabledNids, enabledObserved, currentState):
""" Recursive loop over all configurations
Used for brute force computation
Expand All @@ -164,7 +167,7 @@ def configurationLoop(self, joint, enabledNids, enabledObserved, currentState):
if currVar != len(enabledNids):
# need to continue assembling current configuration
if enabledObserved[currVar] < 0:
for i in xrange(0,joint.shape[currVar]):
for i in range(0,joint.shape[currVar]):
# add new variable value to state
currentState.append(i)
self.configurationLoop(joint, enabledNids, enabledObserved, currentState)
Expand All @@ -175,7 +178,7 @@ def configurationLoop(self, joint, enabledNids, enabledObserved, currentState):
currentState.append(enabledObserved[currVar])
self.configurationLoop(joint, enabledNids, enabledObserved, currentState)
currentState.pop()

else:
# compute value for current configuration
potential = 1.
Expand All @@ -184,18 +187,18 @@ def configurationLoop(self, joint, enabledNids, enabledObserved, currentState):
# figure out which vars are part of factor
# then get current values of those vars in correct order
args = [currentState[enabledNids.index(x.nid)] for x in f.nbrs]

# get value and multiply in
potential = potential * f.P[tuple(args)]

# now add it to joint after correcting state for observed nodes
ind = [currentState[i] if enabledObserved[i] < 0 else 0 for i in range(0, currVar)]
joint[tuple(ind)] = potential

def marginalizeBrute(self, brute, var):
""" Util for marginalizing over joint configuration arrays produced by bruteForce
"""
sumout = range(0, len(brute['names']))
sumout = list(range(0, len(brute['names'])))
del sumout[brute['names'].index(var)]
marg = np.sum(brute['joint'], tuple(sumout))
return marg / np.sum(marg) # normalize to sum to one
Loading

0 comments on commit 2c04361

Please sign in to comment.