Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

C++ solver seed tests, interrupt handler fixes #701

Merged
merged 4 commits into from
Feb 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion gillespy2/solvers/cpp/c_base/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,24 @@
#include <vector>
#include <iostream>

#if defined(WIN32) || defined(WIN32) || defined(__WIN32) || defined(__WIN32__)
#if defined(WIN32) || defined(_WIN32) || defined(__WIN32) || defined(__WIN32__)
#include <windows.h>
#define GPY_PID_GET() ((int) GetCurrentProcessId())
#define GPY_INTERRUPT_HANDLER(handler_name, handler_code) \
static BOOL WINAPI handler_name(DWORD signum) { \
do handler_code while(0); \
return TRUE; \
}
#define GPY_INTERRUPT_INSTALL_HANDLER(handler) SetConsoleCtrlHandler(handler, TRUE)
#else
#include <unistd.h>
#include <csignal>
#define GPY_PID_GET() (getpid())
#define GPY_INTERRUPT_HANDLER(handler_name, handler_code) \
static void handler_name(int signum) { \
do handler_code while(0); \
}
#define GPY_INTERRUPT_INSTALL_HANDLER(handler) signal(SIGINT, handler)
#endif

namespace Gillespy
Expand Down
10 changes: 9 additions & 1 deletion gillespy2/solvers/cpp/c_base/ode_cpp_solver/ODESolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@

namespace Gillespy
{
static volatile bool interrupted = false;

GPY_INTERRUPT_HANDLER(signal_handler, {
interrupted = true;
})

static int f(realtype t, N_Vector y, N_Vector y_dot, void *user_data);

struct UserData
Expand All @@ -45,6 +51,8 @@ namespace Gillespy

void ODESolver(Simulation<double> *simulation, double increment)
{
GPY_INTERRUPT_INSTALL_HANDLER(signal_handler);

// CVODE constants are returned on every success or failure.
// CV_SUCCESS: Operation was successful.
// CV_MEM_NULL: CVODE memory block was not initialized with CVodeCreate.
Expand Down Expand Up @@ -114,7 +122,7 @@ namespace Gillespy
realtype tret = 0;

int current_time = 0;
for (tout = step_length; tout < end_time || cmpf(tout, end_time); tout += step_length)
for (tout = step_length; !interrupted && tout < end_time || cmpf(tout, end_time); tout += step_length)
{
// CV_NORMAL causes the solver to take internal steps until it has reached or just passed the `tout`
// parameter. The solver interpolates in order to return an approximate value of `y(tout)`.
Expand Down
27 changes: 3 additions & 24 deletions gillespy2/solvers/cpp/c_base/ssa_cpp_solver/SSASolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,40 +18,19 @@

#include <cmath>
#include <random>
#include <csignal>
#include <string.h>

#ifdef _WIN32
#include <windows.h>
#undef max
#endif

#include "SSASolver.h"

namespace Gillespy
{
volatile bool interrupted = false;

#ifdef _WIN32
BOOL WINAPI eventHandler(DWORD CtrlType)
{
interrupted = true;
return TRUE;
}
#endif

void signalHandler(int signum)
{
GPY_INTERRUPT_HANDLER(signal_handler, {
interrupted = true;
}
})

void ssa_direct(Simulation<unsigned int> *simulation)
{
#ifdef _WIN32
SetConsoleCtrlHandler(eventHandler, TRUE);
#else
signal(SIGINT, signalHandler);
#endif
GPY_INTERRUPT_INSTALL_HANDLER(signal_handler);

if (simulation)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,21 @@

namespace Gillespy
{
namespace TauHybrid
{
static volatile bool interrupted = false;

bool interrupted = false;

void signalHandler(int signum)
{
interrupted = true;
}
GPY_INTERRUPT_HANDLER(signal_handler, {
interrupted = true;
})

namespace TauHybrid
{
void TauHybridCSolver(HybridSimulation *simulation, std::vector<Event> &events, const double tau_tol)
{
if (simulation == NULL)
{
return;
}
GPY_INTERRUPT_INSTALL_HANDLER(signal_handler);

Model<double> &model = *(simulation->model);
int num_species = model.number_species;
Expand Down Expand Up @@ -77,7 +76,7 @@ namespace Gillespy
TauArgs<double> tau_args = initialize(model, tau_tol);

// Simulate for each trajectory
for (int traj = 0; traj < num_trajectories; traj++)
for (int traj = 0; !interrupted && traj < num_trajectories; traj++)
{
if (traj > 0)
{
Expand Down Expand Up @@ -136,7 +135,7 @@ namespace Gillespy
// For now, a "guard" is put in place to prevent potentially infinite loops from occurring.
unsigned int integration_guard = 1000;

while (integration_guard > 0 && simulation->current_time < simulation->end_time)
while (!interrupted && integration_guard > 0 && simulation->current_time < simulation->end_time)
{
// Compute current propensity values based on existing state.
for (unsigned int rxn_i = 0; rxn_i < num_reactions; ++rxn_i)
Expand All @@ -156,6 +155,8 @@ namespace Gillespy
}
sol.data.propensities[rxn_i] = propensity;
}
if (interrupted)
break;

// Expected tau step is determined.
tau_step = select(
Expand Down Expand Up @@ -304,7 +305,10 @@ namespace Gillespy
current_populations[p_i] = (int) current_state[p_i];
}
}
} while (invalid_state);
} while (invalid_state && !interrupted);

if (interrupted)
break;

// Invalid state after the do-while loop implies that an unrecoverable error has occurred.
// While prior results are considered usable, the current integration results are not.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,12 @@

namespace Gillespy
{
bool interrupted = false;
static volatile bool interrupted = false;
std::mt19937_64 generator;

void signalHandler(int signum)
{
GPY_INTERRUPT_HANDLER(signal_handler, {
interrupted = true;
}
})

std::pair<std::map<std::string, int>, double> get_reactions(
const Gillespy::Model<unsigned int> *model,
Expand Down Expand Up @@ -78,7 +77,7 @@ namespace Gillespy

void tau_leaper(Gillespy::Simulation<unsigned int> *simulation, const double tau_tol)
{
signal(SIGINT, signalHandler);
GPY_INTERRUPT_INSTALL_HANDLER(signal_handler);

if (!simulation)
{
Expand Down
51 changes: 32 additions & 19 deletions test/test_all_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,23 @@
from gillespy2 import NumPySSASolver
from gillespy2 import TauLeapingSolver
from gillespy2 import TauHybridSolver
from gillespy2 import ODECSolver
from gillespy2 import TauLeapingCSolver
from gillespy2 import TauHybridCSolver


class TestAllSolvers(unittest.TestCase):

solvers = [SSACSolver, ODESolver, NumPySSASolver, TauLeapingSolver, TauHybridSolver]
solvers = [
SSACSolver,
ODESolver,
NumPySSASolver,
TauLeapingSolver,
TauHybridSolver,
ODECSolver,
TauLeapingCSolver,
TauHybridCSolver,
]

model = Example()
for sp in model.listOfSpecies.values():
Expand Down Expand Up @@ -65,15 +77,15 @@ def test_return_type_show_labels(self):
self.assertTrue(isinstance(self.labeled_results_more_trajectories[solver][0]['Sp'], np.ndarray))
self.assertTrue(isinstance(self.labeled_results_more_trajectories[solver][0]['Sp'][0], np.float))


def test_random_seed(self):
for solver in self.solvers:
same_results = self.model.run(solver=solver, seed=1)
compare_results = self.model.run(solver=solver,seed=1)
self.assertTrue(np.array_equal(same_results.to_array(), compare_results.to_array()))
if solver.name == 'ODESolver': continue
diff_results = self.model.run(solver=solver, seed=2)
self.assertFalse(np.array_equal(diff_results.to_array(),same_results.to_array()))
with self.subTest(solver=solver.name):
same_results = self.model.run(solver=solver, seed=1)
compare_results = self.model.run(solver=solver,seed=1)
self.assertTrue(np.array_equal(same_results.to_array(), compare_results.to_array()))
if solver.name in ["ODESolver", "ODECSolver"]: continue
diff_results = self.model.run(solver=solver, seed=2)
self.assertFalse(np.array_equal(diff_results.to_array(), same_results.to_array()))

def test_random_seed_unnamed_reactions(self):
model = self.model
Expand All @@ -82,25 +94,26 @@ def test_random_seed_unnamed_reactions(self):
unnamed_rxn = gillespy2.Reaction(reactants={}, products={'Sp':1}, rate=k2)
model.add_reaction(unnamed_rxn)
for solver in self.solvers:
same_results = self.model.run(solver=solver, seed=1)
compare_results = self.model.run(solver=solver,seed=1)
self.assertTrue(np.array_equal(same_results.to_array(), compare_results.to_array()))
if solver.name == 'ODESolver': continue
diff_results = self.model.run(solver=solver, seed=2)
self.assertFalse(np.array_equal(diff_results.to_array(),same_results.to_array()))
with self.subTest(solver=solver.name):
same_results = self.model.run(solver=solver, seed=1)
compare_results = self.model.run(solver=solver,seed=1)
self.assertTrue(np.array_equal(same_results.to_array(), compare_results.to_array()))
if solver.name in ["ODESolver", "ODECSolver"]: continue
diff_results = self.model.run(solver=solver, seed=2)
self.assertFalse(np.array_equal(diff_results.to_array(), same_results.to_array()))

def test_extraneous_args(self):
for solver in self.solvers:
with self.assertLogs(level='WARN'):
with self.subTest(solver=solver.name), self.assertLogs(level='WARN'):
model = Example()
results = model.run(solver=solver, nonsense='ABC')
model.run(solver=solver, nonsense='ABC')

def test_timeout(self):
for solver in self.solvers:
with self.assertLogs(level='WARN'):
with self.subTest(solver=solver.name), self.assertLogs(level='WARN'):
model = Oregonator()
model.timespan(np.linspace(0, 1000000, 101))
results = model.run(solver=solver, timeout=1)
model.timespan(np.linspace(0, 1000000, 1001))
model.run(solver=solver, timeout=0.1)

def test_basic_solver_import(self):
from gillespy2.solvers.numpy.basic_tau_leaping_solver import BasicTauLeapingSolver
Expand Down