Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow energy minimization maker to report energies #1004

Merged
merged 2 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/atomate2/openmm/jobs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def make(

# Run the simulation
start = time.time()
self.run_openmm(sim)
self.run_openmm(sim, dir_name)
elapsed_time = time.time() - start

self._update_interchange(interchange, sim, prev_task)
Expand Down Expand Up @@ -303,6 +303,7 @@ def _add_reporters(
traj_file_name = self._resolve_attr("traj_file_name", prev_task)
traj_file_type = self._resolve_attr("traj_file_type", prev_task)
report_velocities = self._resolve_attr("report_velocities", prev_task)
wrap_traj = self._resolve_attr("wrap_traj", prev_task)

if has_steps & (traj_interval > 0):
writer_kwargs = {}
Expand All @@ -327,7 +328,7 @@ def _add_reporters(
kwargs = dict(
file=str(dir_name / f"{self.traj_file_name}.{traj_file_type}"),
reportInterval=traj_interval,
enforcePeriodicBox=self._resolve_attr("wrap_traj", prev_task),
enforcePeriodicBox=wrap_traj,
)
if report_velocities:
# assert package version
Expand Down Expand Up @@ -364,7 +365,7 @@ def _add_reporters(
)
sim.reporters.append(state_reporter)

def run_openmm(self, simulation: Simulation) -> NoReturn:
def run_openmm(self, sim: Simulation, dir_name: Path) -> NoReturn:
"""Abstract method for running the OpenMM simulation.

This method should be implemented by subclasses to
Expand Down
33 changes: 29 additions & 4 deletions src/atomate2/openmm/jobs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@

import numpy as np
from openmm import Integrator, LangevinMiddleIntegrator, MonteCarloBarostat
from openmm.app import StateDataReporter
from openmm.unit import atmosphere, kelvin, kilojoules_per_mole, nanometer, picoseconds

from atomate2.openmm.jobs.base import BaseOpenMMMaker
from atomate2.openmm.utils import create_list_summing_to

if TYPE_CHECKING:
from pathlib import Path

from emmet.core.openmm import OpenMMTaskDocument
from openmm.app import Simulation

Expand Down Expand Up @@ -41,7 +44,7 @@ class EnergyMinimizationMaker(BaseOpenMMMaker):
tolerance: float = 10
max_iterations: int = 0

def run_openmm(self, sim: Simulation) -> None:
def run_openmm(self, sim: Simulation, dir_name: Path) -> None:
"""Run the energy minimization with OpenMM.

This method performs energy minimization on the molecular system using
Expand All @@ -62,6 +65,28 @@ def run_openmm(self, sim: Simulation) -> None:
maxIterations=self.max_iterations,
)

if self.state_interval > 0:
state = sim.context.getState(
getPositions=True,
getVelocities=True,
getForces=True,
getEnergy=True,
enforcePeriodicBox=self.wrap_traj,
)

state_reporter = StateDataReporter(
file=f"{dir_name / self.state_file_name}.csv",
reportInterval=0,
step=True,
potentialEnergy=True,
kineticEnergy=True,
totalEnergy=True,
temperature=True,
volume=True,
density=True,
)
state_reporter.report(sim, state)


@dataclass
class NPTMaker(BaseOpenMMMaker):
Expand All @@ -87,7 +112,7 @@ class NPTMaker(BaseOpenMMMaker):
pressure: float = 1
pressure_update_frequency: int = 10

def run_openmm(self, sim: Simulation) -> None:
def run_openmm(self, sim: Simulation, dir_name: Path) -> None:
"""Evolve the simulation for self.n_steps in the NPT ensemble.

This adds a Monte Carlo barostat to the system to put it into NPT, runs the
Expand Down Expand Up @@ -138,7 +163,7 @@ class NVTMaker(BaseOpenMMMaker):
name: str = "nvt simulation"
n_steps: int = 1_000_000

def run_openmm(self, sim: Simulation) -> None:
def run_openmm(self, sim: Simulation, dir_name: Path) -> None:
"""Evolve the simulation with OpenMM for self.n_steps.

Parameters
Expand Down Expand Up @@ -177,7 +202,7 @@ class TempChangeMaker(BaseOpenMMMaker):
temp_steps: int | None = None
starting_temperature: float | None = None

def run_openmm(self, sim: Simulation) -> None:
def run_openmm(self, sim: Simulation, dir_name: Path) -> None:
"""Evolve the simulation while gradually changing the temperature.

self.temperature is the final temperature. self.temp_steps
Expand Down
11 changes: 10 additions & 1 deletion tests/openmm_md/flows/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_flow_maker(interchange, run_job):
name="test_production",
tags=["test"],
makers=[
EnergyMinimizationMaker(max_iterations=1),
EnergyMinimizationMaker(max_iterations=1, state_interval=1),
NPTMaker(n_steps=5, pressure=1.0, state_interval=1, traj_interval=1),
OpenMMFlowMaker.anneal_flow(anneal_temp=400, final_temp=300, n_steps=5),
NVTMaker(n_steps=5),
Expand Down Expand Up @@ -157,6 +157,15 @@ def test_flow_maker(interchange, run_job):
calc_output = task_doc.calcs_reversed[0].output
assert len(calc_output.steps_reported) == 5

all_steps = [calc.output.steps_reported for calc in task_doc.calcs_reversed]
assert all_steps == [
[11, 12, 13, 14, 15],
[10],
[8, 9],
[6, 7],
[1, 2, 3, 4, 5],
[0],
]
# Test that the state interval is respected
assert calc_output.steps_reported == list(range(11, 16))
assert calc_output.traj_file == "trajectory5.dcd"
Expand Down
6 changes: 3 additions & 3 deletions tests/openmm_md/jobs/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_make(interchange, tmp_path, run_job):

# monkey patch to allow running the test without openmm

def do_nothing(self, sim):
def do_nothing(self, sim, dir_name):
pass

BaseOpenMMMaker.run_openmm = do_nothing
Expand Down Expand Up @@ -170,7 +170,7 @@ def do_nothing(self, sim):

def test_make_w_velocities(interchange, run_job):
# monkey patch to allow running the test without openmm
def do_nothing(self, sim):
def do_nothing(self, sim, dir_name):
pass

BaseOpenMMMaker.run_openmm = do_nothing
Expand Down Expand Up @@ -215,7 +215,7 @@ def test_make_from_prev(run_job):
maker = BaseOpenMMMaker(n_steps=10)

# monkey patch to allow running the test without openmm
def do_nothing(self, sim):
def do_nothing(self, sim, dir_name):
pass

BaseOpenMMMaker.run_openmm = do_nothing
Expand Down
3 changes: 3 additions & 0 deletions tests/openmm_md/jobs/test_core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import numpy as np
from emmet.core.openmm import OpenMMInterchange
from openmm import XmlSerializer
Expand All @@ -23,6 +25,7 @@ def test_energy_minimization_maker(interchange, run_job):
new_positions = new_state.getPositions(asNumpy=True)

assert not np.all(new_positions == start_positions)
assert (Path(task_doc.calcs_reversed[0].output.dir_name) / "state.csv").exists()


def test_npt_maker(interchange, run_job):
Expand Down
2 changes: 1 addition & 1 deletion tests/openmm_md/jobs/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def test_make_from_prev(openmm_data, run_job):
maker = BaseOpenMMMaker(n_steps=10)

# monkey patch to allow running the test without openmm
def do_nothing(self, sim):
def do_nothing(self, sim, dir_name):
pass

BaseOpenMMMaker.run_openmm = do_nothing
Expand Down
Loading