Skip to content

Commit

Permalink
register jitable prism functions for numba
Browse files Browse the repository at this point in the history
  • Loading branch information
jcapriot committed Oct 27, 2023
1 parent 925a3a4 commit d216370
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 58 deletions.
58 changes: 0 additions & 58 deletions geoana/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,61 +12,3 @@
prism_fxxz,
prism_fxyz,
)

try:
from numba.extending import get_cython_function_address
import ctypes

def __as_ctypes_func(module, function, argument_types):
func_address = get_cython_function_address(module, function)
func_type = ctypes.CFUNCTYPE(*argument_types)
return func_type(func_address)

c_prism = __as_ctypes_func(
'geoana.kernels._extensions.potential_field_prism',
'prism_f',
(ctypes.c_double, ctypes.c_double, ctypes.c_double, ctypes.c_double)
)
c_prism_fz = __as_ctypes_func(
'geoana.kernels._extensions.potential_field_prism',
'prism_fz',
(ctypes.c_double, ctypes.c_double, ctypes.c_double, ctypes.c_double)
)
c_prism_fzz = __as_ctypes_func(
'geoana.kernels._extensions.potential_field_prism',
'prism_fzz',
(ctypes.c_double, ctypes.c_double, ctypes.c_double, ctypes.c_double)
)
c_prism_fzx = __as_ctypes_func(
'geoana.kernels._extensions.potential_field_prism',
'prism_fzx',
(ctypes.c_double, ctypes.c_double, ctypes.c_double, ctypes.c_double)
)
c_prism_fzy = __as_ctypes_func(
'geoana.kernels._extensions.potential_field_prism',
'prism_fzy',
(ctypes.c_double, ctypes.c_double, ctypes.c_double, ctypes.c_double)
)
c_prism_fzzz = __as_ctypes_func(
'geoana.kernels._extensions.potential_field_prism',
'prism_fzzz',
(ctypes.c_double, ctypes.c_double, ctypes.c_double, ctypes.c_double)
)
c_prism_fxxy = __as_ctypes_func(
'geoana.kernels._extensions.potential_field_prism',
'prism_fxxy',
(ctypes.c_double, ctypes.c_double, ctypes.c_double, ctypes.c_double)
)
c_prism_fxxz = __as_ctypes_func(
'geoana.kernels._extensions.potential_field_prism',
'prism_fxxz',
(ctypes.c_double, ctypes.c_double, ctypes.c_double, ctypes.c_double)
)
c_prism_fxyz = __as_ctypes_func(
'geoana.kernels._extensions.potential_field_prism',
'prism_fxyz',
(ctypes.c_double, ctypes.c_double, ctypes.c_double, ctypes.c_double)
)

except ImportError:
pass
54 changes: 54 additions & 0 deletions geoana/kernels/_extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
try:
# register numba jitable versions of the prism functions
# if numba is available (and this module is installed).
from numba.extending import (
overload,
get_cython_function_address
)
from numba import types
import ctypes

from .potential_field_prism import (
prism_f,
prism_fz,
prism_fzz,
prism_fzx,
prism_fzy,
prism_fzzz,
prism_fxxy,
prism_fxxz,
prism_fxyz,
)
funcs = [
prism_f,
prism_fz,
prism_fzz,
prism_fzx,
prism_fzy,
prism_fzzz,
prism_fxxy,
prism_fxxz,
prism_fxyz,
]

def _numba_register_prism_func(prism_func):
module = 'geoana.kernels._extensions.potential_field_prism'
name = prism_func.__name__

func_address = get_cython_function_address(module, name)
func_type = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double, ctypes.c_double, ctypes.c_double)
c_func = func_type(func_address)

@overload(prism_func)
def numba_func(x, y, z):
if isinstance(x, types.Float):
if isinstance(y, types.Float):
if isinstance(z, types.Float):
def f(x, y, z):
return c_func(x, y, z)
return f
for func in funcs:
_numba_register_prism_func(func)

except ImportError as err:
pass

0 comments on commit d216370

Please sign in to comment.