
Reduce import time #438

Open
nspope opened this issue Nov 11, 2024 · 18 comments · May be fixed by #441

Comments

@nspope
Contributor

nspope commented Nov 11, 2024

The import time (due to numba AOT compilation) is driving me nuts-- I wonder if there's a way we can build an extension optionally on install (e.g. triggered by a flag to pip). Jerome and others have had bad experiences with numba extensions, but 30 to 60 seconds for import really slows down the debug cycle, and if extension-building is turned off by default I don't think it'd be an issue.

@hyanwong
Member

I agree it's a pain. I'll check with @jeromekelleher and @benjeffery if there is a sensible solution in this case.

@jeromekelleher
Member

One simple thing is to turn caching on/off using an environment variable. So you turn on caching for dev, and it's off by default: https://github.com/sgkit-dev/sgkit/blob/main/sgkit/accelerate.py

Would this help?

@hyanwong
Member

That sounds like a great suggestion. Thanks @jeromekelleher

@nspope
Contributor Author

nspope commented Nov 13, 2024

That's a good suggestion @jeromekelleher and should be fine for development, thanks. I'm also pondering moving some of the core routines into a rust extension, in the long term.

@jeromekelleher
Member

I'm also pondering moving some of the core routines into a rust extension, in the long term.

eek!

@nspope
Contributor Author

nspope commented Nov 13, 2024

It's not that bad, I hope! What I'm trying to say is I'd like to move some of the low-level EP stuff into a compiled language extension, to save on import time, as this is just going to keep creeping up. Rust is great for numerical stuff and plugs in really nicely to Python.

(One of the motivations here is that we're using tsdate in simulation-based-inference contexts where we generate a bunch of simulations in parallel to train a neural network-- waiting >60 sec for import when running tsdate from the CLI is really a huge time sink in this application).

@hyanwong
Member

Do you want to implement the environment variable caching, or do you want me to, @nspope ?

@jeromekelleher
Member

Should definitely make caching an option for people who know what they're doing, 60s is horrendous.

If you paste in a profile I can comment on some possible ameliorations - that sounds much worse than it should be
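One way to produce such a profile is with the standard-library `cProfile` and `pstats` modules. The sketch below uses a hypothetical helper, `profile_import`, and should be run in a fresh interpreter: a module that is already imported will not execute again, so nothing useful would be recorded. For a per-module breakdown, `python -X importtime -c "import tsdate"` also works.

```python
import cProfile
import importlib
import io
import pstats


def profile_import(module_name, top=20):
    """Profile a single import and return the `top` entries by cumulative time."""
    profiler = cProfile.Profile()
    profiler.enable()
    importlib.import_module(module_name)
    profiler.disable()
    out = io.StringIO()
    stats = pstats.Stats(profiler, stream=out)
    stats.sort_stats("cumulative").print_stats(top)
    return out.getvalue()
```

For example, `print(profile_import("tsdate"))` should show which module-level compilations dominate the import.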

@nspope
Contributor Author

nspope commented Nov 13, 2024

Do you want to implement the environment variable caching

I'll play around with it a bit then ping you both. It's annoying but not super urgent.

hyanwong added a commit to hyanwong/tsdate that referenced this issue Dec 1, 2024
@hyanwong hyanwong linked a pull request Dec 1, 2024 that will close this issue
@hyanwong
Member

hyanwong commented Dec 1, 2024

I tried the approach used in sgkit, in #441, and am getting lots of messages like this:

hypergeo.py:147: NumbaWarning: Cannot cache compiled function "_betaln" as it uses dynamic globals (such as ctypes pointers and large global arrays)
  @numba_jit("f8(f8, f8)")

I get this for the following 27 functions:

 hypergeo.py:147:  "_betaln"  
 hypergeo.py:209:  "_hyp2f1_unity"  
 hypergeo.py:235:  "_hyp2f1_laplace"  
 approx.py:143:  "approximate_gamma_iqr"  
 approx.py:251:  "moments"  
 approx.py:293:  "rootward_moments"  
 approx.py:335:  "leafward_moments"  
 approx.py:369:  "unphased_moments"  
 approx.py:411:  "twin_moments"  
 approx.py:428:  "sideways_moments"  
 approx.py:458:  "mutation_moments"  
 approx.py:500:  "mutation_rootward_moments"  
 approx.py:519:  "mutation_leafward_moments"  
 approx.py:538:  "mutation_unphased_moments"  
 approx.py:698:  "gamma_projection"  
 approx.py:725:  "leafward_projection"  
 approx.py:748:  "rootward_projection"  
 approx.py:771:  "unphased_projection"  
 approx.py:798:  "twin_projection"  
 approx.py:821:  "sideways_projection"  
 approx.py:844:  "mutation_gamma_projection"  
 approx.py:871:  "mutation_leafward_projection"  
 approx.py:895:  "mutation_rootward_projection"  
 approx.py:937:  "mutation_unphased_projection"  
 rescaling.py:448:  "piecewise_scale_posterior"  
 variational.py:274:  "propagate_likelihood"  
 variational.py:481:  "propagate_mutations" 

If I only cache the njit calls that numba doesn't complain about, I cut load time from about 44 to 27 seconds, which is OK, but still not perfect.

@nspope
Contributor Author

nspope commented Dec 1, 2024

Yeah, I ran into the same issue-- it turns out all the EP stuff can't be cached (and that's the bulk of the compilation). Although I could probably drop the reliance on Scipy's Cython library (which is the root of the problem), I think the better way to go is to pull the EP stuff into a low-level extension.

@hyanwong
Copy link
Member

hyanwong commented Dec 1, 2024

numba/numba#6972 seems relevant. For instance, this comment:

When code like this is compiled:

from numba import njit

@njit
def bar():
    pass

@njit
def foo():
    bar()

foo()

The function foo() "sees" bar as a global, and the machine code for foo refers to bar by address. That address cannot be cached, because it will change the next time the program is run, so caching is disabled for function foo.

@nspope
Copy link
Contributor Author

nspope commented Dec 1, 2024

Hm ... it's not going to be possible to refactor things so as to avoid calling njit'd functions inside other njit'd functions (nor do I want to do this).

@benjeffery
Copy link
Member

Would doing something like:

from numba import njit, types

fntype = types.FunctionType(types.void())

@njit(cache=True)
def bar():
    pass

@njit(types.void(fntype), cache=True)
def foo(func):
    func()

foo(bar)

Be a simple refactor that would enable caching?

@nspope
Contributor Author

nspope commented Jan 6, 2025

Thanks @benjeffery -- I'm not quite following: would this work because the signature is explicit? If so, that's already the case in tsdate (we're using explicit signatures everywhere), and the reason it won't cache is that I'm using some ctypes globals to interface with scipy's special functions library here:

import ctypes

from numba.extending import get_cython_function_address

_dbl = ctypes.c_double
# gammaln
_gammaln_addr = get_cython_function_address("scipy.special.cython_special", "gammaln")
_gammaln_functype = ctypes.CFUNCTYPE(_dbl, _dbl)
_gammaln_f8 = _gammaln_functype(_gammaln_addr)
# gammainc
_gammainc_addr = get_cython_function_address("scipy.special.cython_special", "gammainc")
_gammainc_functype = ctypes.CFUNCTYPE(_dbl, _dbl, _dbl)
_gammainc_f8 = _gammainc_functype(_gammainc_addr)
# gammaincinv
_gammaincinv_addr = get_cython_function_address(
    "scipy.special.cython_special", "gammaincinv"
)
_gammaincinv_functype = ctypes.CFUNCTYPE(_dbl, _dbl, _dbl)
_gammaincinv_f8 = _gammaincinv_functype(_gammaincinv_addr)
# erfinv
_erfinv_addr = get_cython_function_address(
    "scipy.special.cython_special", "__pyx_fuse_0erfinv"
)
_erfinv_functype = ctypes.CFUNCTYPE(_dbl, _dbl)
_erfinv_f8 = _erfinv_functype(_erfinv_addr)

which results in

NumbaWarning: Cannot cache compiled function "_betaln" as it uses dynamic globals (such as ctypes pointers and large global arrays)
  @numba_jit("f8(f8, f8)")
# etc

though now that I look at it carefully, it might be possible to avoid using scipy altogether here depending on what parts of math are supported by numba.

@benjeffery
Member

The trick isn't the explicit signature, but passing the global as an argument to the function:

def _betaln(gamma_func, p, q):

The issue here is that numba needs to know the exact type signature for all arguments to a function. Using a global makes it an implicit argument of the function (technically a name looked up in the function's module globals when it runs, rather than a parameter). But unlike normal arguments, it's complicated for numba to watch for changes to the global's type.

Passing the global down from the nearest caller that is not numba-jitted should enable caching.
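The distinction shows up even in plain Python, without numba: a global is a name resolved through the function's `__globals__` when it runs, whereas an argument is part of the signature that a compiler can specialize on. The sketch below uses `math.lgamma` as a stand-in for the ctypes wrapper `_gammaln_f8` so that it runs without SciPy; `betaln_global` and `betaln_arg` are hypothetical names.

```python
import math

# Stand-in for the module-level ctypes wrapper; math.lgamma is used only so
# the example runs without SciPy.
_gammaln_f8 = math.lgamma


def betaln_global(p, q):
    # _gammaln_f8 is a hidden input: it is looked up in the module's globals
    # at call time and does not appear in the signature.
    return _gammaln_f8(p) + _gammaln_f8(q) - _gammaln_f8(p + q)


def betaln_arg(gamma_func, p, q):
    # The same dependency made explicit: the (non-jitted) caller passes it in.
    return gamma_func(p) + gamma_func(q) - gamma_func(p + q)


# The global shows up as a name to load; the argument is a local parameter.
print(betaln_global.__code__.co_names)
print(betaln_arg.__code__.co_varnames[:3])
```

The second form is the shape suggested above: the jitted function only ever sees what its caller hands it.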

@jeromekelleher
Member

Yeah, globals are to be avoided with numba for a few reasons.

@nspope
Contributor Author

nspope commented Jan 7, 2025

Got it-- thanks!


4 participants