Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple JAX & C++ code generation #2615

Merged
merged 11 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,6 @@
"amici_model = import_petab_problem(\n",
" petab_problem,\n",
" verbose=False,\n",
" compile_=True,\n",
" jax=False, # load the amici model this time\n",
")\n",
"\n",
Expand Down
163 changes: 20 additions & 143 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
TYPE_CHECKING,
Literal,
)
from itertools import chain

import sympy as sp

Expand Down Expand Up @@ -56,7 +55,6 @@
AmiciCxxCodePrinter,
get_switch_statement,
)
from .jaxcodeprinter import AmiciJaxCodePrinter
from .de_model import DEModel
from .de_model_components import *
from .import_utils import (
Expand Down Expand Up @@ -146,10 +144,7 @@ class DEExporter:
If the given model uses special functions, this set contains hints for
model building.

:ivar _code_printer_jax:
Code printer to generate JAX code

:ivar _code_printer_cpp:
:ivar _code_printer:
Code printer to generate C++ code

:ivar generate_sensitivity_code:
Expand Down Expand Up @@ -218,15 +213,14 @@ def __init__(
self.set_name(model_name)
self.set_paths(outdir)

self._code_printer_cpp = AmiciCxxCodePrinter()
self._code_printer_jax = AmiciJaxCodePrinter()
self._code_printer = AmiciCxxCodePrinter()
for fun in CUSTOM_FUNCTIONS:
self._code_printer_cpp.known_functions[fun["sympy"]] = fun["c++"]
self._code_printer.known_functions[fun["sympy"]] = fun["c++"]

# Signatures and properties of generated model functions (see
# include/amici/model.h for details)
self.model: DEModel = de_model
self._code_printer_cpp.known_functions.update(
self._code_printer.known_functions.update(
splines.spline_user_functions(
self.model._splines, self._get_index("p")
)
Expand All @@ -249,7 +243,6 @@ def generate_model_code(self) -> None:
sp.Pow, "_eval_derivative", _custom_pow_eval_derivative
):
self._prepare_model_folder()
self._generate_jax_code()
self._generate_c_code()
self._generate_m_code()

Expand Down Expand Up @@ -277,121 +270,6 @@ def _prepare_model_folder(self) -> None:
if os.path.isfile(file_path):
os.remove(file_path)

@log_execution_time("generating jax code", logger)
def _generate_jax_code(self) -> None:
try:
from amici.jax.model import JAXModel
except ImportError:
logger.warning(
"Could not import JAXModel. JAX code will not be generated."
)
return

eq_names = (
"xdot",
"w",
"x0",
"y",
"sigmay",
"Jy",
"x_solver",
"x_rdata",
"total_cl",
)
sym_names = ("x", "tcl", "w", "my", "y", "sigmay", "x_rdata")

indent = 8

def jnp_array_str(array) -> str:
elems = ", ".join(str(s) for s in array)

return f"jnp.array([{elems}])"

# replaces Heaviside variables with corresponding functions
subs_heaviside = dict(
zip(
self.model.sym("h"),
[sp.Heaviside(x) for x in self.model.eq("root")],
strict=True,
)
)
# replaces observables with a generic my variable
subs_observables = dict(
zip(
self.model.sym("my"),
[sp.Symbol("my")] * len(self.model.sym("my")),
strict=True,
)
)

tpl_data = {
# assign named variable using corresponding algebraic formula (function body)
**{
f"{eq_name.upper()}_EQ": "\n".join(
self._code_printer_jax._get_sym_lines(
(str(strip_pysb(s)) for s in self.model.sym(eq_name)),
self.model.eq(eq_name).subs(
{**subs_heaviside, **subs_observables}
),
indent,
)
)[indent:] # remove indent for first line
for eq_name in eq_names
},
# create jax array from concatenation of named variables
**{
f"{eq_name.upper()}_RET": jnp_array_str(
strip_pysb(s) for s in self.model.sym(eq_name)
)
if self.model.sym(eq_name)
else "jnp.array([])"
for eq_name in eq_names
},
# assign named variables from a jax array
**{
f"{sym_name.upper()}_SYMS": "".join(
str(strip_pysb(s)) + ", " for s in self.model.sym(sym_name)
)
if self.model.sym(sym_name)
else "_"
for sym_name in sym_names
},
# tuple of variable names (ids as they are unique)
**{
f"{sym_name.upper()}_IDS": "".join(
f'"{strip_pysb(s)}", ' for s in self.model.sym(sym_name)
)
if self.model.sym(sym_name)
else "tuple()"
for sym_name in ("p", "k", "y", "x")
},
**{
# in jax model we do not need to distinguish between p (parameters) and
# k (fixed parameters) so we use a single variable combining both
"PK_SYMS": "".join(
str(strip_pysb(s)) + ", "
for s in chain(self.model.sym("p"), self.model.sym("k"))
),
"PK_IDS": "".join(
f'"{strip_pysb(s)}", '
for s in chain(self.model.sym("p"), self.model.sym("k"))
),
"MODEL_NAME": self.model_name,
# keep track of the API version that the model was generated with so we
# can flag conflicts in the future
"MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'",
},
}
os.makedirs(
os.path.join(self.model_path, self.model_name), exist_ok=True
)

apply_template(
os.path.join(amiciModulePath, "jax.template.py"),
os.path.join(self.model_path, self.model_name, "jax.py"),
tpl_data,
)

def _generate_c_code(self) -> None:
"""
Create C++ code files for the model based on
Expand Down Expand Up @@ -795,7 +673,7 @@ def _get_function_body(
lines = []

if len(equations) == 0 or (
isinstance(equations, (sp.Matrix, sp.ImmutableDenseMatrix))
isinstance(equations, sp.Matrix | sp.ImmutableDenseMatrix)
and min(equations.shape) == 0
):
# dJydy is a list
Expand Down Expand Up @@ -852,7 +730,7 @@ def _get_function_body(
f"reinitialization_state_idxs.cend(), {index}) != "
"reinitialization_state_idxs.cend())",
f" {function}[{index}] = "
f"{self._code_printer_cpp.doprint(formula)};",
f"{self._code_printer.doprint(formula)};",
]
)
cases[ipar] = expressions
Expand All @@ -867,12 +745,12 @@ def _get_function_body(
f"reinitialization_state_idxs.cend(), {index}) != "
"reinitialization_state_idxs.cend())\n "
f"{function}[{index}] = "
f"{self._code_printer_cpp.doprint(formula)};"
f"{self._code_printer.doprint(formula)};"
)

elif function in event_functions:
cases = {
ie: self._code_printer_cpp._get_sym_lines_array(
ie: self._code_printer._get_sym_lines_array(
equations[ie], function, 0
)
for ie in range(self.model.num_events())
Expand All @@ -885,7 +763,7 @@ def _get_function_body(
for ie, inner_equations in enumerate(equations):
inner_lines = []
inner_cases = {
ipar: self._code_printer_cpp._get_sym_lines_array(
ipar: self._code_printer._get_sym_lines_array(
inner_equations[:, ipar], function, 0
)
for ipar in range(self.model.num_par())
Expand All @@ -900,7 +778,7 @@ def _get_function_body(
and equations.shape[1] == self.model.num_par()
):
cases = {
ipar: self._code_printer_cpp._get_sym_lines_array(
ipar: self._code_printer._get_sym_lines_array(
equations[:, ipar], function, 0
)
for ipar in range(self.model.num_par())
Expand All @@ -910,15 +788,15 @@ def _get_function_body(
elif function in multiobs_functions:
if function == "dJydy":
cases = {
iobs: self._code_printer_cpp._get_sym_lines_array(
iobs: self._code_printer._get_sym_lines_array(
equations[iobs], function, 0
)
for iobs in range(self.model.num_obs())
if not smart_is_zero_matrix(equations[iobs])
}
else:
cases = {
iobs: self._code_printer_cpp._get_sym_lines_array(
iobs: self._code_printer._get_sym_lines_array(
equations[:, iobs], function, 0
)
for iobs in range(equations.shape[1])
Expand Down Expand Up @@ -948,7 +826,7 @@ def _get_function_body(
tmp_equations = sp.Matrix(
[equations[i] for i in static_idxs]
)
tmp_lines = self._code_printer_cpp._get_sym_lines_symbols(
tmp_lines = self._code_printer._get_sym_lines_symbols(
tmp_symbols,
tmp_equations,
function,
Expand All @@ -974,7 +852,7 @@ def _get_function_body(
[equations[i] for i in dynamic_idxs]
)

tmp_lines = self._code_printer_cpp._get_sym_lines_symbols(
tmp_lines = self._code_printer._get_sym_lines_symbols(
tmp_symbols,
tmp_equations,
function,
Expand All @@ -986,12 +864,12 @@ def _get_function_body(
lines.extend(tmp_lines)

else:
lines += self._code_printer_cpp._get_sym_lines_symbols(
lines += self._code_printer._get_sym_lines_symbols(
symbols, equations, function, 4
)

else:
lines += self._code_printer_cpp._get_sym_lines_array(
lines += self._code_printer._get_sym_lines_array(
equations, function, 4
)

Expand Down Expand Up @@ -1136,8 +1014,7 @@ def _write_model_header_cpp(self) -> None:
)
),
"NDXDOTDX_EXPLICIT": len(self.model.sparsesym("dxdotdx_explicit")),
"NDJYDY": "std::vector<int>{%s}"
% ",".join(str(len(x)) for x in self.model.sparsesym("dJydy")),
"NDJYDY": f"std::vector<int>{{{','.join(str(len(x)) for x in self.model.sparsesym('dJydy'))}}}",
"NDXRDATADXSOLVER": len(self.model.sparsesym("dx_rdatadx_solver")),
"NDXRDATADTCL": len(self.model.sparsesym("dx_rdatadtcl")),
"NDTOTALCLDXRDATA": len(self.model.sparsesym("dtotal_cldx_rdata")),
Expand All @@ -1147,10 +1024,10 @@ def _write_model_header_cpp(self) -> None:
"NK": self.model.num_const(),
"O2MODE": "amici::SecondOrderMode::none",
# using code printer ensures proper handling of nan/inf
"PARAMETERS": self._code_printer_cpp.doprint(self.model.val("p"))[
"PARAMETERS": self._code_printer.doprint(self.model.val("p"))[
1:-1
],
"FIXED_PARAMETERS": self._code_printer_cpp.doprint(
"FIXED_PARAMETERS": self._code_printer.doprint(
self.model.val("k")
)[1:-1],
"PARAMETER_NAMES_INITIALIZER_LIST": self._get_symbol_name_initializer_list(
Expand Down Expand Up @@ -1344,7 +1221,7 @@ def _get_symbol_id_initializer_list(self, name: str) -> str:
Template initializer list of ids
"""
return "\n".join(
f'"{self._code_printer_cpp.doprint(symbol)}", // {name}[{idx}]'
f'"{self._code_printer.doprint(symbol)}", // {name}[{idx}]'
for idx, symbol in enumerate(self.model.sym(name))
)

Expand Down
Loading
Loading