Simple GP example causes RecursionError #111

Closed
bmorris3 opened this issue Sep 7, 2020 · 7 comments
Labels
bug Something isn't working

Comments

bmorris3 commented Sep 7, 2020

When I define a custom mean model via exoplanet's GP object, I get a recursion-depth-exceeded error that I don't fully understand. This is almost certainly my misunderstanding of how a mean model should be defined for exoplanet, but I can't figure out what's wrong. Any tips? Thanks!

import numpy as np
import pymc3 as pm
from exoplanet.gp import GP, terms

# Generate some fake data with a heaviside function
x = np.linspace(-10, 10)
x0 = 2
amp = 10
noise = 0.1
y = np.heaviside(x - x0, amp) + noise * np.random.randn(*x.shape)
yerr = noise * np.ones_like(y)


# define a mean model with two free parameters
class MeanModel:
    def __init__(self):
        self.x0 = pm.Uniform('x0', lower=-10, upper=10)
        self.amp = pm.Uniform('amp', lower=0, upper=100)
        
    def __call__(self, X): 
        return pm.math.where(pm.math.ge(X - self.x0, 0), 
                             self.amp, 0)
        
# define the GP marginal likelihood
with pm.Model() as model:
    mean = MeanModel()
    gp = GP(terms.Matern32Term(sigma=noise, rho=20), x, yerr**2, mean=mean)
    gp.marginal('gp', observed=y)
    
# find MAP solution (this works)
with model: 
    map_soln = pm.find_MAP()

# run NUTS (this doesn't work)
with model: 
    trace = pm.sample(draws=100, tune=100, init='jitter+adapt_full')

It should run without error, but instead I get a very long traceback, starting with:

---------------------------------------------------------------------------
RecursionError                            Traceback (most recent call last)
<ipython-input-159-836b4a29cc91> in <module>
      1 with model:
----> 2     trace = pm.sample(draws=100, tune=100, init='jitter+adapt_full')

~/miniconda3/lib/python3.7/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
    538         _print_step_hierarchy(step)
    539         try:
--> 540             trace = _mp_sample(**sample_args, **parallel_args)
    541         except pickle.PickleError:
    542             _log.warning("Could not pickle model, sampling singlethreaded.")

~/miniconda3/lib/python3.7/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)
   1454         progressbar,
   1455         mp_ctx=mp_ctx,
-> 1456         pickle_backend=pickle_backend,
   1457     )
   1458     try:

~/miniconda3/lib/python3.7/site-packages/pymc3/parallel_sampling.py in __init__(self, draws, tune, chains, cores, seeds, start_points, step_method, start_chain_num, progressbar, mp_ctx, pickle_backend)
    430         if mp_ctx.get_start_method() != 'fork':
    431             if pickle_backend == 'pickle':
--> 432                 step_method_pickled = pickle.dumps(step_method, protocol=-1)
    433             elif pickle_backend == 'dill':
    434                 try:

~/miniconda3/lib/python3.7/site-packages/pymc3/distributions/distribution.py in __getstate__(self)
    425         # always defined in the notebook and won't be pickled correctly.
    426         # Fix https://github.com/pymc-devs/pymc3/issues/3844
--> 427         logp = dill.dumps(self.logp)
    428         vals = self.__dict__.copy()
    429         vals['logp'] = logp

~/miniconda3/lib/python3.7/site-packages/dill/_dill.py in dumps(obj, protocol, byref, fmode, recurse, **kwds)
    263     """pickle an object to a string"""
    264     file = StringIO()
--> 265     dump(obj, file, protocol, byref, fmode, recurse, **kwds)#, strictio)
    266     return file.getvalue()
    267 

~/miniconda3/lib/python3.7/site-packages/dill/_dill.py in dump(obj, file, protocol, byref, fmode, recurse, **kwds)
    257     _kwds = kwds.copy()
    258     _kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse))
--> 259     Pickler(file, protocol, **_kwds).dump(obj)
    260     return
    261 

~/miniconda3/lib/python3.7/site-packages/dill/_dill.py in dump(self, obj)
    444             raise PicklingError(msg)
    445         else:
--> 446             StockPickler.dump(self, obj)
    447         stack.clear()  # clear record of 'recursion-sensitive' pickled objects
    448         return

~/miniconda3/lib/python3.7/pickle.py in dump(self, obj)
    435         if self.proto >= 4:
    436             self.framer.start_framing()
--> 437         self.save(obj)
    438         self.write(STOP)
    439         self.framer.end_framing()

~/miniconda3/lib/python3.7/pickle.py in save(self, obj, save_persistent_id)
    502         f = self.dispatch.get(t)
    503         if f is not None:
--> 504             f(self, obj) # Call unbound method with explicit self
    505             return

and ending in:

~/miniconda3/lib/python3.7/pickle.py in save_dict(self, obj)
    857 
    858         self.memoize(obj)
--> 859         self._batch_setitems(obj.items())
    860 
    861     dispatch[dict] = save_dict

~/miniconda3/lib/python3.7/pickle.py in _batch_setitems(self, items)
    883                 for k, v in tmp:
    884                     save(k)
--> 885                     save(v)
    886                 write(SETITEMS)
    887             elif n:

~/miniconda3/lib/python3.7/pickle.py in save(self, obj, save_persistent_id)
    522             reduce = getattr(obj, "__reduce_ex__", None)
    523             if reduce is not None:
--> 524                 rv = reduce(self.proto)
    525             else:
    526                 reduce = getattr(obj, "__reduce__", None)

... last 218 frames repeated, from the frame below ...

~/miniconda3/lib/python3.7/site-packages/pymc3/distributions/distribution.py in __getstate__(self)
    425         # always defined in the notebook and won't be pickled correctly.
    426         # Fix https://github.com/pymc-devs/pymc3/issues/3844
--> 427         logp = dill.dumps(self.logp)
    428         vals = self.__dict__.copy()
    429         vals['logp'] = logp

RecursionError: maximum recursion depth exceeded in comparison

Setup

  • Version of exoplanet: 0.3.2
  • Operating system: Mac OS X 10.11.6
  • Python version & installation method (pip, conda, etc.): miniconda 3.7

Update: I fixed some silly parameters in the fake data and I'm still getting the same error.

bmorris3 added the bug label on Sep 7, 2020

bmorris3 commented Sep 7, 2020

Continuing to investigate: it seems like @dfm has seen a similar pickling-related issue before in pymc-devs/pymc#3844. The suggested fix there is:

import multiprocessing as mp
mp.set_start_method("fork")

but that doesn't seem to help.

bmorris3 commented Sep 7, 2020

This does go away if pm.sample(..., cores=1). Hmm!
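For the record, a minimal sketch of that workaround (same sampler arguments as in the snippet above): with cores=1 the chains run sequentially, so the step method never goes through the multiprocessing pickling step that triggers the recursion.

with model:
    # single-core sampling: chains run one after another, nothing gets pickled
    trace = pm.sample(draws=100, tune=100, init='jitter+adapt_full', cores=1)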

dfm commented Sep 7, 2020

As you say, this is really more of a PyMC3 bug (they are insistent on forkserver), since this is a pickling error deep inside PyMC3. I was able to run your code as a script by wrapping it in an if __name__ == "__main__": block and passing the following argument to sample:

mp_ctx=mp.get_context("fork")

You're going to have some pain trying to ever get this to work in a Jupyter notebook on Mac or Windows with recent versions of PyMC3, I think.
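Putting those two pieces together, a minimal sketch of the script version (untested here; it reuses the data and model from the original report, with the mean model written as a small closure for brevity):

import multiprocessing as mp

import numpy as np
import pymc3 as pm
from exoplanet.gp import GP, terms

if __name__ == "__main__":
    # fake data with a step, as in the original report
    x = np.linspace(-10, 10)
    noise = 0.1
    y = 10 * np.heaviside(x - 2, 1) + noise * np.random.randn(*x.shape)
    yerr = noise * np.ones_like(y)

    with pm.Model() as model:
        # step-function mean with two free parameters
        x0 = pm.Uniform('x0', lower=-10, upper=10)
        amp = pm.Uniform('amp', lower=0, upper=100)

        def mean(X):
            return pm.math.where(pm.math.ge(X - x0, 0), amp, 0)

        gp = GP(terms.Matern32Term(sigma=noise, rho=20), x, yerr ** 2, mean=mean)
        gp.marginal('gp', observed=y)

        # forcing the "fork" start method keeps the sampler out of the
        # pickling path that recurses
        trace = pm.sample(draws=100, tune=100, init='jitter+adapt_full',
                          mp_ctx=mp.get_context("fork"))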

An alternative is to apply the mean function yourself, e.g. something like [untested]:

gp = GP(terms.Matern32Term(sigma=noise, rho=20), x, yerr**2)
gp.marginal('gp', observed=y - mean(x))

bmorris3 commented Sep 7, 2020

Thanks so much for the quick help! Using your tips, I find that the following works in a script, an IPython session, and a Jupyter notebook:

%matplotlib inline
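# (the magic above is notebook-only; drop it when running this as a plain script)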
import numpy as np

import pymc3 as pm
from exoplanet.gp import GP, terms
from exoplanet import eval_in_model
import multiprocessing as mp

# Generate some fake data with a heaviside function
x = np.linspace(-10, 10)
true_x0 = 2
true_amp = 10
noise = 0.1
y = true_amp * np.heaviside(x - true_x0, 1) + noise * np.random.randn(*x.shape)
yerr = noise * np.ones_like(y)


class MeanModel:
    def __init__(self, x0, amp):
        self.x0 = x0
        self.amp = amp

    def __call__(self, X):
        return pm.math.where(pm.math.ge(X - self.x0, 0), self.amp, 0)


# define the GP marginal likelihood
with pm.Model() as model:
    x0 = pm.Uniform('x0', lower=-10, upper=10, testval=1.1 * true_x0)
    amp = pm.Uniform('amp', lower=0, upper=100, testval=0.8 * true_amp)

    mean = MeanModel(x0, amp)

    gp = GP(terms.Matern32Term(sigma=noise, rho=20), x, yerr ** 2)
    gp.marginal('gp', observed=y - mean(x))

# find MAP solution (this works)
with model:
    map_soln = pm.find_MAP()

with model:
    trace = pm.sample(start=map_soln,
                      draws=1000, tune=1000, init='jitter+adapt_full',
                      mp_ctx=mp.get_context("fork"))

# quick script to plot the results:
import matplotlib.pyplot as plt

plt.errorbar(x, y, yerr, fmt='.', color='k', ecolor='silver')

for i in np.random.randint(0, len(trace), size=10):
    with model:
        mu, var = eval_in_model(
            gp.predict(x, return_var=True), trace[i]
        )
        mean_eval = eval_in_model(
            mean(x), trace[i]
        )

    plt.fill_between(x, mu + np.sqrt(var) + mean_eval,
                     mu - np.sqrt(var) + mean_eval, alpha=0.2)
plt.show()

from corner import corner
corner(pm.trace_to_dataframe(trace))
plt.show()

To summarize, the key changes are passing:

mp_ctx=mp.get_context("fork")

to pm.sample, and subtracting the mean model from the observations yourself instead of passing mean= to the GP constructor:

    gp = GP(terms.Matern32Term(sigma=noise, rho=20), x, yerr ** 2)
    gp.marginal('gp', observed=y - mean(x))

dfm commented Sep 7, 2020

Awesome! It's possible that you don't need the context if you subtract the mean manually. Did you try that?

bmorris3 commented Sep 7, 2020

If I manually subtract the mean but leave out the mp_ctx kwarg I still get the RecursionError.
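For concreteness, this is the variant I mean (a sketch; same arguments as the working example above, just without mp_ctx):

# mean subtracted manually via observed=y - mean(x) as above, but no mp_ctx;
# this still hits the RecursionError when sampling on multiple cores
with model:
    trace = pm.sample(start=map_soln, draws=1000, tune=1000,
                      init='jitter+adapt_full')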

dfm commented Sep 7, 2020

That surprises me - very strange!
