Skip to content

Commit

Permalink
Make sympy code optimizations optional (#1878)
Browse files Browse the repository at this point in the history
Turns out that applying sympy optimizations as introduced
in #1377 is
potentially very costly. Therefore, they are disabled by
default. They can be enabled by setting
`AmiciCxxCodePrinter.optimizations` as shown in the test
case.
  • Loading branch information
dweindl authored Oct 28, 2022
1 parent e43d0e8 commit 55299c9
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 6 deletions.
27 changes: 21 additions & 6 deletions python/sdist/amici/cxxcodeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,34 @@
import itertools
import os
import re
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Iterable

import sympy as sp
from sympy.codegen.rewriting import optimize, optims_c99
from sympy.codegen.rewriting import optimize, Optimization
from sympy.printing.cxx import CXX11CodePrinter
from sympy.utilities.iterables import numbered_symbols
from toposort import toposort


class AmiciCxxCodePrinter(CXX11CodePrinter):
"""C++ code printer"""
"""
C++ code printer
Attributes
----------
extract_cse:
Whether to extract common subexpression during code printing.
Currently controlled by environment variable ``AMICI_EXTRACT_CSE``.
optimizations:
Iterable of :class:`sympy.codegen.rewriting.Optimization`s to optimize
generated code (e.g. :data:`sympy.codegen.rewriting.Optimization` for
optimizations, such as ``log(1 + x)`` --> ``logp1(x)``).
Applying these optimizations is potentially quite costly.
"""

optimizations: Iterable[Optimization] = ()

def __init__(self, optimize_code: bool = True):
def __init__(self):
"""
Create code printer
Expand All @@ -29,8 +44,8 @@ def __init__(self, optimize_code: bool = True):

# Floating-point optimizations
# e.g., log(1 + x) --> logp1(x)
if optimize_code:
self._fpoptimizer = lambda x: optimize(x, optims_c99)
if self.optimizations:
self._fpoptimizer = lambda x: optimize(x, self.optimizations)
else:
self._fpoptimizer = None

Expand Down
14 changes: 14 additions & 0 deletions python/tests/test_cxxcodeprinter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from amici.cxxcodeprinter import AmiciCxxCodePrinter
from sympy.codegen.rewriting import optims_c99
import sympy as sp


def test_optimizations():
"""Check that AmiciCxxCodePrinter handles optimizations correctly."""
try:
old_optim = AmiciCxxCodePrinter.optimizations
AmiciCxxCodePrinter.optimizations = optims_c99
cp = AmiciCxxCodePrinter()
assert "expm1" in cp.doprint(sp.sympify("exp(x) - 1"))
finally:
AmiciCxxCodePrinter.optimizations = old_optim

0 comments on commit 55299c9

Please sign in to comment.