Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple JAX & C++ code generation #2615

Merged
merged 11 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions python/sdist/amici/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,6 @@ def get_model(self) -> amici.Model:
"""Create a model instance."""
...

def get_jax_model(self) -> JAXModel: ...

AmiciModel = Union[amici.Model, amici.ModelPtr]
else:
ModelModule = ModuleType
Expand Down
30 changes: 0 additions & 30 deletions python/sdist/amici/__init__.template.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
"""AMICI-generated module for model TPL_MODELNAME"""

import datetime
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING
import amici


if TYPE_CHECKING:
from amici.jax import JAXModel

# Ensure we are binary-compatible, see #556
if "TPL_AMICI_VERSION" != amici.__version__:
raise amici.AmiciVersionError(
Expand Down Expand Up @@ -38,28 +32,4 @@
# when the model package is imported via `import`
TPL_MODELNAME._model_module = sys.modules[__name__]


def get_jax_model() -> "JAXModel":
# If the model directory was meanwhile overwritten, this would load the
# new version, which would not match the previously imported extension.
# This is not allowed, as it would lead to inconsistencies.
jax_py_file = Path(__file__).parent / "jax.py"
jax_py_file = jax_py_file.resolve()
t_imported = TPL_MODELNAME._get_import_time() # noqa: protected-access
t_modified = os.path.getmtime(jax_py_file)
if t_imported < t_modified:
t_imp_str = datetime.datetime.fromtimestamp(t_imported).isoformat()
t_mod_str = datetime.datetime.fromtimestamp(t_modified).isoformat()
raise RuntimeError(
f"Refusing to import {jax_py_file} which was changed since "
f"TPL_MODELNAME was imported. This is to avoid inconsistencies "
"between the different model implementations.\n"
f"Imported at {t_imp_str}\nModified at {t_mod_str}.\n"
"Import the module with a different name or restart the "
"Python kernel."
)
jax = amici._module_from_path("jax", jax_py_file)
return jax.JAXModel_TPL_MODELNAME()


__version__ = "TPL_PACKAGE_VERSION"
163 changes: 20 additions & 143 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
TYPE_CHECKING,
Literal,
)
from itertools import chain

import sympy as sp

Expand Down Expand Up @@ -56,7 +55,6 @@
AmiciCxxCodePrinter,
get_switch_statement,
)
from .jaxcodeprinter import AmiciJaxCodePrinter
from .de_model import DEModel
from .de_model_components import *
from .import_utils import (
Expand Down Expand Up @@ -146,10 +144,7 @@ class DEExporter:
If the given model uses special functions, this set contains hints for
model building.

:ivar _code_printer_jax:
Code printer to generate JAX code

:ivar _code_printer_cpp:
:ivar _code_printer:
Code printer to generate C++ code

:ivar generate_sensitivity_code:
Expand Down Expand Up @@ -218,15 +213,14 @@ def __init__(
self.set_name(model_name)
self.set_paths(outdir)

self._code_printer_cpp = AmiciCxxCodePrinter()
self._code_printer_jax = AmiciJaxCodePrinter()
self._code_printer = AmiciCxxCodePrinter()
for fun in CUSTOM_FUNCTIONS:
self._code_printer_cpp.known_functions[fun["sympy"]] = fun["c++"]
self._code_printer.known_functions[fun["sympy"]] = fun["c++"]

# Signatures and properties of generated model functions (see
# include/amici/model.h for details)
self.model: DEModel = de_model
self._code_printer_cpp.known_functions.update(
self._code_printer.known_functions.update(
splines.spline_user_functions(
self.model._splines, self._get_index("p")
)
Expand All @@ -249,7 +243,6 @@ def generate_model_code(self) -> None:
sp.Pow, "_eval_derivative", _custom_pow_eval_derivative
):
self._prepare_model_folder()
self._generate_jax_code()
self._generate_c_code()
self._generate_m_code()

Expand Down Expand Up @@ -277,121 +270,6 @@ def _prepare_model_folder(self) -> None:
if os.path.isfile(file_path):
os.remove(file_path)

@log_execution_time("generating jax code", logger)
def _generate_jax_code(self) -> None:
try:
from amici.jax.model import JAXModel
except ImportError:
logger.warning(
"Could not import JAXModel. JAX code will not be generated."
)
return

eq_names = (
"xdot",
"w",
"x0",
"y",
"sigmay",
"Jy",
"x_solver",
"x_rdata",
"total_cl",
)
sym_names = ("x", "tcl", "w", "my", "y", "sigmay", "x_rdata")

indent = 8

def jnp_array_str(array) -> str:
elems = ", ".join(str(s) for s in array)

return f"jnp.array([{elems}])"

# replaces Heaviside variables with corresponding functions
subs_heaviside = dict(
zip(
self.model.sym("h"),
[sp.Heaviside(x) for x in self.model.eq("root")],
strict=True,
)
)
# replaces observables with a generic my variable
subs_observables = dict(
zip(
self.model.sym("my"),
[sp.Symbol("my")] * len(self.model.sym("my")),
strict=True,
)
)

tpl_data = {
# assign named variable using corresponding algebraic formula (function body)
**{
f"{eq_name.upper()}_EQ": "\n".join(
self._code_printer_jax._get_sym_lines(
(str(strip_pysb(s)) for s in self.model.sym(eq_name)),
self.model.eq(eq_name).subs(
{**subs_heaviside, **subs_observables}
),
indent,
)
)[indent:] # remove indent for first line
for eq_name in eq_names
},
# create jax array from concatenation of named variables
**{
f"{eq_name.upper()}_RET": jnp_array_str(
strip_pysb(s) for s in self.model.sym(eq_name)
)
if self.model.sym(eq_name)
else "jnp.array([])"
for eq_name in eq_names
},
# assign named variables from a jax array
**{
f"{sym_name.upper()}_SYMS": "".join(
str(strip_pysb(s)) + ", " for s in self.model.sym(sym_name)
)
if self.model.sym(sym_name)
else "_"
for sym_name in sym_names
},
# tuple of variable names (ids as they are unique)
**{
f"{sym_name.upper()}_IDS": "".join(
f'"{strip_pysb(s)}", ' for s in self.model.sym(sym_name)
)
if self.model.sym(sym_name)
else "tuple()"
for sym_name in ("p", "k", "y", "x")
},
**{
# in jax model we do not need to distinguish between p (parameters) and
# k (fixed parameters) so we use a single variable combining both
"PK_SYMS": "".join(
str(strip_pysb(s)) + ", "
for s in chain(self.model.sym("p"), self.model.sym("k"))
),
"PK_IDS": "".join(
f'"{strip_pysb(s)}", '
for s in chain(self.model.sym("p"), self.model.sym("k"))
),
"MODEL_NAME": self.model_name,
# keep track of the API version that the model was generated with so we
# can flag conflicts in the future
"MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'",
},
}
os.makedirs(
os.path.join(self.model_path, self.model_name), exist_ok=True
)

apply_template(
os.path.join(amiciModulePath, "jax.template.py"),
os.path.join(self.model_path, self.model_name, "jax.py"),
tpl_data,
)

def _generate_c_code(self) -> None:
"""
Create C++ code files for the model based on
Expand Down Expand Up @@ -795,7 +673,7 @@ def _get_function_body(
lines = []

if len(equations) == 0 or (
isinstance(equations, (sp.Matrix, sp.ImmutableDenseMatrix))
isinstance(equations, sp.Matrix | sp.ImmutableDenseMatrix)
and min(equations.shape) == 0
):
# dJydy is a list
Expand Down Expand Up @@ -852,7 +730,7 @@ def _get_function_body(
f"reinitialization_state_idxs.cend(), {index}) != "
"reinitialization_state_idxs.cend())",
f" {function}[{index}] = "
f"{self._code_printer_cpp.doprint(formula)};",
f"{self._code_printer.doprint(formula)};",
]
)
cases[ipar] = expressions
Expand All @@ -867,12 +745,12 @@ def _get_function_body(
f"reinitialization_state_idxs.cend(), {index}) != "
"reinitialization_state_idxs.cend())\n "
f"{function}[{index}] = "
f"{self._code_printer_cpp.doprint(formula)};"
f"{self._code_printer.doprint(formula)};"
)

elif function in event_functions:
cases = {
ie: self._code_printer_cpp._get_sym_lines_array(
ie: self._code_printer._get_sym_lines_array(
equations[ie], function, 0
)
for ie in range(self.model.num_events())
Expand All @@ -885,7 +763,7 @@ def _get_function_body(
for ie, inner_equations in enumerate(equations):
inner_lines = []
inner_cases = {
ipar: self._code_printer_cpp._get_sym_lines_array(
ipar: self._code_printer._get_sym_lines_array(
inner_equations[:, ipar], function, 0
)
for ipar in range(self.model.num_par())
Expand All @@ -900,7 +778,7 @@ def _get_function_body(
and equations.shape[1] == self.model.num_par()
):
cases = {
ipar: self._code_printer_cpp._get_sym_lines_array(
ipar: self._code_printer._get_sym_lines_array(
equations[:, ipar], function, 0
)
for ipar in range(self.model.num_par())
Expand All @@ -910,15 +788,15 @@ def _get_function_body(
elif function in multiobs_functions:
if function == "dJydy":
cases = {
iobs: self._code_printer_cpp._get_sym_lines_array(
iobs: self._code_printer._get_sym_lines_array(
equations[iobs], function, 0
)
for iobs in range(self.model.num_obs())
if not smart_is_zero_matrix(equations[iobs])
}
else:
cases = {
iobs: self._code_printer_cpp._get_sym_lines_array(
iobs: self._code_printer._get_sym_lines_array(
equations[:, iobs], function, 0
)
for iobs in range(equations.shape[1])
Expand Down Expand Up @@ -948,7 +826,7 @@ def _get_function_body(
tmp_equations = sp.Matrix(
[equations[i] for i in static_idxs]
)
tmp_lines = self._code_printer_cpp._get_sym_lines_symbols(
tmp_lines = self._code_printer._get_sym_lines_symbols(
tmp_symbols,
tmp_equations,
function,
Expand All @@ -974,7 +852,7 @@ def _get_function_body(
[equations[i] for i in dynamic_idxs]
)

tmp_lines = self._code_printer_cpp._get_sym_lines_symbols(
tmp_lines = self._code_printer._get_sym_lines_symbols(
tmp_symbols,
tmp_equations,
function,
Expand All @@ -986,12 +864,12 @@ def _get_function_body(
lines.extend(tmp_lines)

else:
lines += self._code_printer_cpp._get_sym_lines_symbols(
lines += self._code_printer._get_sym_lines_symbols(
symbols, equations, function, 4
)

else:
lines += self._code_printer_cpp._get_sym_lines_array(
lines += self._code_printer._get_sym_lines_array(
equations, function, 4
)

Expand Down Expand Up @@ -1136,8 +1014,7 @@ def _write_model_header_cpp(self) -> None:
)
),
"NDXDOTDX_EXPLICIT": len(self.model.sparsesym("dxdotdx_explicit")),
"NDJYDY": "std::vector<int>{%s}"
% ",".join(str(len(x)) for x in self.model.sparsesym("dJydy")),
"NDJYDY": f"std::vector<int>{{{','.join(str(len(x)) for x in self.model.sparsesym('dJydy'))}}}",
"NDXRDATADXSOLVER": len(self.model.sparsesym("dx_rdatadx_solver")),
"NDXRDATADTCL": len(self.model.sparsesym("dx_rdatadtcl")),
"NDTOTALCLDXRDATA": len(self.model.sparsesym("dtotal_cldx_rdata")),
Expand All @@ -1147,10 +1024,10 @@ def _write_model_header_cpp(self) -> None:
"NK": self.model.num_const(),
"O2MODE": "amici::SecondOrderMode::none",
# using code printer ensures proper handling of nan/inf
"PARAMETERS": self._code_printer_cpp.doprint(self.model.val("p"))[
"PARAMETERS": self._code_printer.doprint(self.model.val("p"))[
1:-1
],
"FIXED_PARAMETERS": self._code_printer_cpp.doprint(
"FIXED_PARAMETERS": self._code_printer.doprint(
self.model.val("k")
)[1:-1],
"PARAMETER_NAMES_INITIALIZER_LIST": self._get_symbol_name_initializer_list(
Expand Down Expand Up @@ -1344,7 +1221,7 @@ def _get_symbol_id_initializer_list(self, name: str) -> str:
Template initializer list of ids
"""
return "\n".join(
f'"{self._code_printer_cpp.doprint(symbol)}", // {name}[{idx}]'
f'"{self._code_printer.doprint(symbol)}", // {name}[{idx}]'
for idx, symbol in enumerate(self.model.sym(name))
)

Expand Down
Loading
Loading