Skip to content

Commit

Permalink
Fixing the MPI behaviour described in #38
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Nov 30, 2012
1 parent f5ec9e1 commit a314fa7
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 21 deletions.
19 changes: 14 additions & 5 deletions emcee/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,20 +163,29 @@ def map(self, function, tasks):
F = _function_wrapper(function)

# Tell all the workers what function to use.
requests = []
for i in range(self.size):
self.comm.isend(F, dest=i + 1)
r = self.comm.isend(F, dest=i + 1)
requests.append(r)

# Send all the tasks off. Do not wait for them to be received,
# just continue.
# Wait until all of the workers have responded. See:
# https://gist.github.com/4176241
MPI.Request.waitall(requests)

# Send all the tasks off and wait for them to be received.
# Again, see the bug in the above gist.
requests = []
for i, task in enumerate(tasks):
worker = i % self.size + 1
if self.debug:
print(u"Sent task {0} to worker {1} with tag {2}."
.format(task, worker, i))
self.comm.isend(task, dest=worker, tag=i)
results = []
r = self.comm.isend(task, dest=worker, tag=i)
requests.append(r)
MPI.Request.waitall(requests)

# Now wait for the answers.
results = []
for i in range(ntask):
worker = i % self.size + 1
if self.debug:
Expand Down
20 changes: 6 additions & 14 deletions examples/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
def lnprob(x):
return -0.5 * np.sum(x ** 2)

ndim = 10
nwalkers = 200
ndim = 50
nwalkers = 250
p0 = [np.random.rand(ndim) for i in xrange(nwalkers)]

# Initialize the MPI-based pool used for parallelization.
pool = MPIPool()

if not pool.is_master():
Expand All @@ -30,7 +32,7 @@ def lnprob(x):
sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob, pool=pool)

# Run 100 steps as a burn-in.
pos, prob, state = sampler.run_mcmc(p0, 1)
pos, prob, state = sampler.run_mcmc(p0, 100)

# Reset the chain to remove the burn-in samples.
sampler.reset()
Expand All @@ -43,16 +45,6 @@ def lnprob(x):
pool.close()

# Print out the mean acceptance fraction. In general, acceptance_fraction
# has an entry for each walker so, in this case, it is a 100-dimensional
# has an entry for each walker so, in this case, it is a 250-dimensional
# vector.
print(u"Mean acceptance fraction: ", np.mean(sampler.acceptance_fraction))

# Finally, you can plot the projected histograms of the samples using
# matplotlib as follows (as long as you have it installed).
try:
import matplotlib.pyplot as pl
except ImportError:
print(u"Try installing matplotlib to generate some sweet plots...")
else:
pl.hist(sampler.flatchain[:, 0], 100)
pl.show()
3 changes: 1 addition & 2 deletions examples/quickstart.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def lnprob(x, mu, icov):
sampler.run_mcmc(pos, 1000, rstate0=state)

# Print out the mean acceptance fraction. In general, acceptance_fraction
# has an entry for each walker so, in this case, it is a 100-dimensional
# has an entry for each walker so, in this case, it is a 250-dimensional
# vector.
print("Mean acceptance fraction:", np.mean(sampler.acceptance_fraction))

Expand All @@ -67,4 +67,3 @@ def lnprob(x, mu, icov):
else:
pl.hist(sampler.flatchain[:,0], 100)
pl.show()

0 comments on commit a314fa7

Please sign in to comment.