53 save model state customization #100

Merged
merged 20 commits on Aug 5, 2024

Commits (20)
b1d2e25
Fixed some imports, corrected some typos, and added some comments.
Jul 25, 2024
5345109
Moved savestate parameter into config and renamed to checkpoint_period.
Jul 30, 2024
a1cc3ba
Ignoring temporary test folders.
Jul 30, 2024
d3da61c
Storing config to both modelId.yaml and modelId_step.yaml
Jul 30, 2024
867c615
Added script to generate a version ID file.
Jul 30, 2024
2147dfa
Added code version file when saving model.
Jul 31, 2024
61e5ef2
Moved restart parameter from model initialization to run method.
Jul 31, 2024
4b74d01
Storing step count with code version on save.
Jul 31, 2024
43504f0
Added option to save model at predefined milestone steps.
Jul 31, 2024
4a2e338
Fixed test_model_init_savestate test to also check the existence of t…
Jul 31, 2024
370bf8f
Added test for saving milestone.
Jul 31, 2024
9bbaf3d
Added option to restart from milestone.
Jul 31, 2024
3add841
Added test for restarting from milestone.
Jul 31, 2024
994192c
Merge branch 'development' into 53_save_model_state_customization
vmgaribay Aug 1, 2024
60a84a6
Merge branch '53_save_model_state_customization' into merging_data_co…
Aug 1, 2024
458b24b
Moved test_data_collection_period and test_data_collection_list into …
Aug 1, 2024
c212708
Merge branch 'development' into 53_save_model_state_customization
Aug 5, 2024
11355ff
Merge branch 'development' into 53_save_model_state_customization
vanlankveldthijs Aug 5, 2024
a2fbd61
Merge branch 'development' into merging_data_collection_and_save_mode…
Aug 5, 2024
d7a7df3
Merge branch '53_save_model_state_customization' into merging_data_co…
Aug 5, 2024
4 changes: 4 additions & 0 deletions dgl_ptm/.gitignore
@@ -29,3 +29,7 @@ env
env3
venv
venv3

# Testing
my_model/
test_model/
2 changes: 2 additions & 0 deletions dgl_ptm/dgl_ptm/config.py
@@ -213,6 +213,8 @@ class Config(BaseModel):
model_graph: object = None # TODO: might be possible to move it from config to model
step_count: int = 0
step_target: PositiveInt = 5
checkpoint_period: int = 10
milestones: Optional[List[PositiveInt]] = None
Comment on lines +216 to +217
Member: Can we add documentation/definitions for these two parameters somewhere?
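One possible shape for that documentation, using pydantic Field descriptions (a sketch only; the wording is inferred from the run() logic later in this diff, not from existing docs):

from typing import List, Optional
from pydantic import BaseModel, Field, PositiveInt

class Config(BaseModel):
    checkpoint_period: int = Field(
        default=10,
        description='Save a restartable checkpoint every N steps; '
                    'values <= 0 disable checkpointing.')
    milestones: Optional[List[PositiveInt]] = Field(
        default=None,
        description='Steps at which the model state is saved permanently '
                    'to its own milestone_<step> folder.')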

steering_parameters: SteeringParams = SteeringParams()
alpha_dist: AlphaDist = AlphaDist()
capital_dist: CapitalDist = CapitalDist()
1 change: 1 addition & 0 deletions dgl_ptm/dgl_ptm/config.yaml
@@ -8,6 +8,7 @@ initial_graph_args:
new_node_edges: 1
step_count: 0
step_target: 5
checkpoint_period: 10
cost_vals:
- 0.0
- 0.45
115 changes: 85 additions & 30 deletions dgl_ptm/dgl_ptm/model/initialize_model.py
@@ -1,16 +1,15 @@
import copy
from pathlib import Path
import dgl
import torch
import pickle
import logging

from pathlib import Path
from dgl.data.utils import save_graphs, load_graphs

from dgl_ptm.network.network_creation import network_creation
from dgl_ptm.model.step import ptm_step
from dgl_ptm.agentInteraction.weight_update import weight_update
from dgl_ptm.model.data_collection import data_collection
from dgl.data.utils import save_graphs, load_graphs
from dgl_ptm.config import Config, CONFIG
from dgl_ptm.util.network_metrics import average_degree

@@ -57,7 +56,7 @@ def sample_distribution_tensor(type, distParameters, nSamples, round=False, deci
uniform_samples = torch.rand(size)
sample_ppf = torch.sqrt(torch.tensor(2.0)) * torch.erfinv(2 *(cdf_min + (cdf_max - cdf_min) * uniform_samples) - 1)

dist = destParameters[0] + destParameters[1] * sample_ppf
dist = distParameters[0] + distParameters[1] * sample_ppf

else:
raise NotImplementedError('Currently only uniform, normal, multinomial, and bernoulli distributions are supported')
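The corrected line is the location-scale step of inverse-CDF sampling from a truncated normal. A standalone illustration (assumed scalar parameters and bounds; cdf_min and cdf_max computed from the standard normal CDF, matching the formula above):

import torch

mean, std = 2.0, 0.5
low, high = 1.0, 3.0  # hypothetical truncation bounds

def norm_cdf(x):  # standard normal CDF
    return 0.5 * (1 + torch.erf(torch.tensor(x) / torch.sqrt(torch.tensor(2.0))))

cdf_min = norm_cdf((low - mean) / std)
cdf_max = norm_cdf((high - mean) / std)

uniform_samples = torch.rand(10_000)
sample_ppf = torch.sqrt(torch.tensor(2.0)) * torch.erfinv(
    2 * (cdf_min + (cdf_max - cdf_min) * uniform_samples) - 1)

# The fixed line: distParameters[0] + distParameters[1] * sample_ppf
dist = mean + std * sample_ppf
assert low <= dist.min() and dist.max() <= high  # samples stay within bounds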
@@ -135,21 +134,16 @@ class PovertyTrapModel(Model):

"""

def __init__(self,*, model_identifier, restart=False, savestate=10):
def __init__(self,*, model_identifier):
"""
restore from a savestate or create a PVT model instance.
restore from a checkpoint or create a PVT model instance.
Checks whether a model identifier has been specified.

param: model_identifier: str, required. Identifier for the model. Used to save and load model states.
param: restart: boolean, optional. If True, the model is run from last
saved step. Default False.
param: savestate: int, optional. If provided, the model state is saved
on this frequency. Default is 10 i.e. every 10th time step.

"""

super().__init__(model_identifier = model_identifier)
self.restart = restart
self.savestate = savestate

# default values
self.device = CONFIG.device
@@ -169,8 +163,14 @@ def __init__(self,*, model_identifier, restart=False, savestate=10):
self.model_graph = CONFIG.model_graph
self.step_count = CONFIG.step_count
self.step_target = CONFIG.step_target
self.checkpoint_period = CONFIG.checkpoint_period
self.milestones = CONFIG.milestones
self.steering_parameters = CONFIG.steering_parameters

# Code version.
self.version = Path('version.md').read_text().splitlines()[0]
Comment on lines +170 to +171
Member: I like the idea of saving the software version along with the model output, but how practical is it to use the git commit hash? It is not clear how this works in practice. For example, who runs the regen_version.sh file to store the commit hash, and how often is it run? Do the users use git to run experiments? Instead, we could use the software version specified in pyproject.toml as version = "0.1.0".

Author: Based on our discussion, I will add some guidance for this in the CONTRIBUTING.md file.

In short, the version is stored in a tracked file. This makes it independent of the repository: the version file can be shipped with the built code (outside the repo) and the code would still work. However, tracking the file means it cannot use the 'current commit hash', because committing this version file (with the current hash) would create a newer hash.
Instead, we should run the regen_version.sh script before any release (or other 'important code version'), for example as part of the last commit before tagging a release. It may be possible to do this with GitHub Actions, or we may have to do it manually, in which case the procedure should be made very clear in CONTRIBUTING.md.

I understand this may not be the ideal approach, but I could not find a better one.

Member: I also found that the relative path of Path('version.md') causes problems. Can you please make it absolute?

Author: I changed "code version" to "process version" to reflect that this captures other aspects of the process as well.

Tbh, I don't really know how best to get the path of the package root (as opposed to the model save directory). Will discuss.

Member (quoting the above): parent_path = Path(__file__).resolve().parents; parents is a sequence, see the pathlib docs.
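Following that suggestion, a minimal sketch of resolving version.md relative to the package instead of the working directory (assuming the layout in this PR, where version.md sits two levels above initialize_model.py):

from pathlib import Path

def _read_code_version():
    # parents[0] is model/, parents[1] is the inner dgl_ptm package,
    # parents[2] is the directory holding version.md (assumed layout).
    version_file = Path(__file__).resolve().parents[2] / 'version.md'
    return version_file.read_text().splitlines()[0]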



def set_model_parameters(self, *, parameterFilePath=None, **kwargs):
"""
Load or set model parameters
@@ -220,10 +220,12 @@ def set_model_parameters(self, *, parameterFilePath=None, **kwargs):
self.steering_parameters['npath'] = str(parent_dir / Path(cfg.steering_parameters.npath))
self.steering_parameters['epath'] = str(parent_dir / Path(cfg.steering_parameters.epath))

# save updated config to yaml file
# save updated config to yaml files
cfg_filename = parent_dir / f'{self._model_identifier}.yaml'
cfg.to_yaml(cfg_filename)
logger.warning(f'The model parameters are saved to {cfg_filename}.')
cfg_filename_step = parent_dir / f'{self._model_identifier}_{self.step_count}.yaml'
cfg.to_yaml(cfg_filename_step)
logger.warning(f'The model parameters are saved to {cfg_filename} and {cfg_filename_step}.')
Comment on lines +223 to +228
Member: It is not clear why the configuration (cfg) is stored twice, i.e. to both cfg_filename and cfg_filename_step. The function set_model_parameters is usually run once at the beginning, when step_count is 0. Also, no tests are added in test_model.py::test_set_model_parameters to verify these changes.

Author: Based on our discussion, I will

  • remove the part where it stores to cfg_filename, only storing to cfg_filename_step.
  • when continuing from a stored model, find the config file with the highest step count that does not exceed the step where the process will continue (sketched below).
  • add tests for all of this functionality.

Author: I will also see if I can migrate the config parameters into a member of the PovertyTrapModel, as opposed to copying the values of the config.
Copying the values means we cannot just save the model parameters at will, because we might have lost some (non-copied) parameters. Migrating to a Config member of PovertyTrapModel would mean we can save the complete config at will (and I can move saving the config into a separate method).

Author: As of commit bfbb990 these features have been implemented, together with some tests.

Author: TODO: add 'make unique' functionality for restarting multiple runs from the same step.
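A possible shape for the config lookup from the second bullet above (a hypothetical helper, not merged code; the modelId_step.yaml naming follows this PR's save logic):

import re
from pathlib import Path

# Hypothetical helper: return the saved config with the highest step count
# that does not exceed the step from which the run will continue.
def _find_restart_config(parent_dir, model_id, restart_step):
    best, best_step = None, -1
    for cfg in Path(parent_dir).glob(f'{model_id}_*.yaml'):
        match = re.fullmatch(rf'{re.escape(model_id)}_(\d+)\.yaml', cfg.name)
        if match:
            step = int(match.group(1))
            if best_step < step <= restart_step:
                best, best_step = cfg, step
    return best  # None when no suitable config exists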


def initialize_model(self):
"""
@@ -392,11 +394,27 @@ def step(self):
#TODO add model dump here. Also check against previous save to avoid overwriting
raise RuntimeError(f'execution of step failed for step {self.step_count}')

def run(self):
""" run the model for each step until the step_target is reached."""
def run(self, restart=False):
"""
run the model for each step until the step_target is reached.

param: restart: boolean or int or a pair of ints, optional.
If True, the model is run from the last checkpoint,
Member: We should be able to identify the "last checkpoint" by a "milestone" and "step" as (milestone, step), correct? This can be merged with the third case to simplify the code.

Author: Based on our discussion, there was some confusion here. The tuple case would be (step, instance), not (milestone, step). I'll try to clarify this in the docstring.

if an int, the model is run from the first milestone at that step,
if a pair of ints, the model is run from that milestone at that step.
Comment on lines +403 to +404
Member: It seems that the first case is a special instance of the second case, where the first item is 1, i.e. restart = (1, int). Can these two cases be merged to simplify the code?

Author: Based on our discussion, I will merge the int and tuple cases. The argument can then be a boolean (for running from a checkpoint) or a tuple (for running from a milestone).

Author (vanlankveldthijs, Aug 8, 2024): Thinking about it again, it is slightly unclear what to do with the second value of the tuple. We could allow setting it to None, but then the desired milestone to load might be either

  • milestone 0 at the specified step count, or
  • the most recent milestone at the specified step count. I'm leaning towards this option.

Author: To add some context: the use case where the milestone with step and increment matters is comparing different policy choices that run until the same step.
Imagine we have some starting state (say at step 500) and we want to compare policy choices A, B, and C after applying each for the same amount of time (say 500 steps). We would restart the model 3 times from a milestone at step 500 (each with different config parameters representing the different policies), run until step 1000 each time, and store the final states. In this case, we must be able to store multiple milestones at the same time step (1000).

Author: TODO: add a logging message stating the milestone that will be used.
TODO: only keep the tuple case; remove the int case.

Default False.
"""

self.inputs = None
if isinstance(restart, bool):
if restart:
self.inputs = _load_model(f'./{self._model_identifier}')
elif isinstance(restart, int):
self.inputs = _load_model(f'./{self._model_identifier}/milestone_{restart}')
elif isinstance(restart, tuple):
self.inputs = _load_model(f'./{self._model_identifier}/milestone_{restart[0]}_{restart[1]}')
Member: I couldn't find where in the code the model state is saved in the directory f'./{self._model_identifier}/milestone_{restart[0]}_{restart[1]}'. This is also not tested.

Author: Based on our discussion, I will

  • add an mkdir command to _save_model so we have control over where any (milestone) state save creates a new directory.
  • add a test for the milestone_X_Y case.

Author: As of commit e98f547 this has been implemented.
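For illustration, the three forms of restart dispatched above (the paths follow from the code; the step and instance values are hypothetical):

model.run()                  # fresh run, nothing loaded
model.run(restart=True)      # resume from the checkpoint in ./<model_id>/
model.run(restart=500)       # resume from ./<model_id>/milestone_500
model.run(restart=(500, 1))  # resume from ./<model_id>/milestone_500_1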


if self.restart:
self.inputs = _load_model(f'./{self._model_identifier}')
if self.inputs:
self.model_graph = copy.deepcopy(self.inputs["model_graph"])
#self.model_data = self.inputs["model_data"]
self.generator_state = self.inputs["generator_state"]
@@ -407,23 +425,42 @@ def run(self):
while self.step_count < self.step_target:
self.step()

# save the model state every step reported by savestate
if self.savestate and self.step_count % self.savestate == 0:
# save the model state every checkpoint_period steps and at specific milestones.
# checkpoint saves overwrite the previous checkpoint; milestones get unique folders.
save_checkpoint = 0 < self.checkpoint_period and self.step_count % self.checkpoint_period == 0
save_milestone = self.milestones and self.step_count in self.milestones
if save_checkpoint or save_milestone:
self.inputs = {
'model_graph': copy.deepcopy(self.model_graph),
#'model_data': copy.deepcopy(self.model_data),
'generator_state': generator.get_state(),
'step_count': self.step_count
'step_count': self.step_count,
'code_version': self.version
}
_save_model(f'./{self._model_identifier}', self.inputs)

# Note that a single step could be both a checkpoint and a milestone.
# The checkpoint could be necessary to restore a crashed process while
# the milestone is required output.
if save_checkpoint:
_save_model(f'./{self._model_identifier}', self.inputs)
if save_milestone:
milestone_path = _make_path_unique(f'./{self._model_identifier}/milestone_{self.step_count}')
_save_model(milestone_path, self.inputs)
Member: I couldn't find in the code where this directory milestone_path is actually created, nor how permission issues are handled.

Author: As stated above, I will add an mkdir command to _save_model so we have control over where any (milestone) state save creates a new directory.
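A minimal sketch of that proposed change (an assumption based on this thread; not code in this diff):

from pathlib import Path

def _save_model(path, inputs):
    # Proposed addition: create the target directory so milestone saves
    # control their own folder creation.
    Path(path).mkdir(parents=True, exist_ok=True)
    ...  # then save the graph, generator state, and version as below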


def _make_path_unique(path):
    if Path(path).exists():
        incr = 1
        while Path(f'{path}_{incr}').exists():
            incr += 1
        path = f'{path}_{incr}'
    return path
Comment on lines +450 to +456
Member: Adding a number to the path is not informative. I suggest generating unique directories using a timestamp:

from datetime import datetime
now = datetime.now()
timestamp = now.strftime("%Y%m%d_%H%M%S")
unique_path = f"{path}_{timestamp}"

Author (vanlankveldthijs, Aug 8, 2024): While this would make it more straightforward to make a unique path, I think it would also make it harder to identify the milestone when trying to load it. For example, if you're trying to continue from 'the second milestone', you would have to know (or look up) which timestamp that milestone has.

This might still be the best approach, or we could combine the concepts by storing milestones as {path}_{step}_{timestamp}.

Note that with any method for creating unique paths, we can always get the time at which a path was created from the OS.

Author: TODO: keep instance numbering instead of the timestamp.
TODO: add a docstring to _make_path_unique.
TODO: rename incr to instance.

Author: TODO: add a (str) description field to the config file. Not to be used by the program, but for the researcher to describe the run or policy being tested. Add a comment describing this value.
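To illustrate the instance numbering under discussion (hypothetical paths):

# Assuming ./my_model/milestone_500 already exists from an earlier save:
_make_path_unique('./my_model/milestone_500')  # -> './my_model/milestone_500_1'
# A further save at the same step would yield './my_model/milestone_500_2'.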


def _save_model(path, inputs):
""" save the model_graph, generator_state and model_data in files."""
""" save the model_graph, generator_state and code_version in files."""

# save the model_graph with a label
graph_label = {'step_count': torch.tensor([inputs["step_count"]])}
save_graphs(str(Path(path) / "model_graphs.bin"), inputs["model_graph"], graph_label)
graph_labels = {'step_count': torch.tensor([inputs["step_count"]])}
save_graphs(str(Path(path) / "model_graph.bin"), inputs["model_graph"], graph_labels)

# save the generator_state
with open(Path(path) / "generator_state.bin", 'wb') as file:
@@ -433,16 +470,20 @@ def _save_model(path, inputs):
#with open(Path(path) / "model_data.bin", 'wb') as file:
# pickle.dump([inputs["model_data"], inputs["step_count"]], file)

# save the code version
with open(Path(path) / "version.md", 'w') as file:
file.writelines([inputs["code_version"] + '\n', f'step={inputs["step_count"]}\n'])


def _load_model(path):
# Load model graphs
path_model_graph = Path(path) / "model_graphs.bin"
# Load model graph
path_model_graph = Path(path) / "model_graph.bin"
if not path_model_graph.is_file():
raise ValueError(f'The path {path_model_graph} is not a file.')

graph, graph_lebel = load_graphs(str(path_model_graph))
graph, graph_labels = load_graphs(str(path_model_graph))
graph = graph[0]
graph_step = graph_lebel['step_count'].tolist()[0]
graph_step = graph_labels['step_count'].tolist()[0]

# Load generator_state
path_generator_state = Path(path) / "generator_state.bin"
@@ -460,10 +501,23 @@ def _load_model(path):
#with open(path_model_data, 'rb') as file:
# data, data_step = pickle.load(file)

# Load code version
path_code_version = Path(path) / "version.md"
if not path_code_version.is_file():
raise ValueError(f'The path {path_code_version} is not a file.')

with open(path_code_version, 'r') as file:
code_version = file.read().splitlines()[0] # splitlines drops the trailing newline written by _save_model, so the version check below compares cleanly

# Check if graph_step, generator_step and data_step are the same
if graph_step != generator_step: #or graph_step != data_step:
msg = 'The step count in the model_graph and generator_state are not the same.'# and model_data are not the same.'
raise ValueError(msg)

# Check if the saved version and current code version are the same
version = Path('version.md').read_text().splitlines()[0]
if code_version != version:
logger.warning(f'Warning: loading model generated using earlier code version: {code_version}.')

# Show which step is loaded
logger.warning(f'Loading model state from step {generator_step}.')
@@ -472,6 +526,7 @@ def _load_model(path):
'model_graph': graph,
#'model_data': data,
'generator_state': generator,
'step_count': generator_step
'step_count': generator_step,
'code_version': code_version
}
return inputs
3 changes: 3 additions & 0 deletions dgl_ptm/regen_version.sh
@@ -0,0 +1,3 @@
#!/bin/bash
Member: Can we add some documentation in this file, or in an md file in a docs folder?

Author: The current md file in the docs folder is about running on Snellius. I think it's better to add some documentation to the CONTRIBUTING.md file.

git rev-parse HEAD > version.md

100 changes: 63 additions & 37 deletions dgl_ptm/tests/test_model.py
@@ -47,36 +47,6 @@ def test_ptm_step_timestep1(self, model):
model.step() # timestep 1
assert Path('my_model/edge_data/1.zarr').exists()


class TestDataCollection:
def test_data_collection(self, model):
data_collection(model.model_graph, timestep=0, npath = model.steering_parameters['npath'],
epath = model.steering_parameters['epath'], ndata = model.steering_parameters['ndata'],
edata = model.steering_parameters['edata'], format = model.steering_parameters['format'],
mode = model.steering_parameters['mode'])

assert Path('my_model/agent_data.zarr').exists()
assert Path('my_model/edge_data/0.zarr').exists()

def test_data_collection_timestep1(self, model):
model.step() # timestep 0
data_collection(model.model_graph, timestep=1, npath = model.steering_parameters['npath'],
epath = model.steering_parameters['epath'], ndata = model.steering_parameters['ndata'],
edata = model.steering_parameters['edata'], format = model.steering_parameters['format'],
mode = model.steering_parameters['mode'])

assert Path('my_model/agent_data.zarr').exists()
assert Path('my_model/edge_data/0.zarr').exists()
assert Path('my_model/edge_data/1.zarr').exists()

# check if dimension 'n_time' exists in agent_data.zarr
agent_data = xr.open_zarr('my_model/agent_data.zarr')
assert 'n_time' in agent_data.dims

# check variable names in edge_data/1.zarr
edge_data = xr.open_zarr('my_model/edge_data/1.zarr')
assert 'weight' in edge_data.variables

def test_data_collection_period(self, model):
if Path('my_model/edge_data/').exists():
shutil.rmtree('my_model/edge_data/')
@@ -135,6 +105,36 @@ def test_data_collection_period_and_list(self, model):
assert Path('my_model/edge_data/9.zarr').exists()


class TestDataCollection:
def test_data_collection(self, model):
data_collection(model.model_graph, timestep=0, npath = model.steering_parameters['npath'],
epath = model.steering_parameters['epath'], ndata = model.steering_parameters['ndata'],
edata = model.steering_parameters['edata'], format = model.steering_parameters['format'],
mode = model.steering_parameters['mode'])

assert Path('my_model/agent_data.zarr').exists()
assert Path('my_model/edge_data/0.zarr').exists()

def test_data_collection_timestep1(self, model):
model.step() # timestep 0
data_collection(model.model_graph, timestep=1, npath = model.steering_parameters['npath'],
epath = model.steering_parameters['epath'], ndata = model.steering_parameters['ndata'],
edata = model.steering_parameters['edata'], format = model.steering_parameters['format'],
mode = model.steering_parameters['mode'])

assert Path('my_model/agent_data.zarr').exists()
assert Path('my_model/edge_data/0.zarr').exists()
assert Path('my_model/edge_data/1.zarr').exists()

# check if dimension 'n_time' exists in agent_data.zarr
agent_data = xr.open_zarr('my_model/agent_data.zarr')
assert 'n_time' in agent_data.dims

# check variable names in edge_data/1.zarr
edge_data = xr.open_zarr('my_model/edge_data/1.zarr')
assert 'weight' in edge_data.variables


class TestInitializeModel:
def test_set_model_parameters(self):
model = dgl_ptm.PovertyTrapModel(model_identifier='test_model')
@@ -226,31 +226,57 @@ def test_run(self, model):
assert model.model_graph.number_of_nodes() == 100

def test_model_init_savestate(self, model):
model.savestate = 1
model.checkpoint_period = 1
model.run()

assert model.inputs is not None
assert Path('my_model/model_graphs.bin').exists()
assert Path('my_model/model_graph.bin').exists()
assert Path('my_model/generator_state.bin').exists()
assert Path('my_model/version.md').exists()
assert model.inputs["step_count"] == 5

def test_model_init_savestate_not_default(self, model):
model.savestate = 2
model.checkpoint_period = 2
model.run()

assert model.inputs["step_count"] == 4

def test_model_init_restart(self, model):
model.savestate = 1
model.checkpoint_period = 1
model.step_target = 3 # only run the model till step 3
model.run()
expected_generator_state = set(model.inputs["generator_state"].tolist())

model.restart = True
model.step_target = 5 # contiune the model till step 5
model.run()
model.step_target = 5 # restart the model and run till step 5
model.run(restart=True)
stored_generator_state = set(model.inputs["generator_state"].tolist())

assert model.inputs is not None
assert model.inputs["step_count"] == 5
assert stored_generator_state == expected_generator_state

def test_model_milestone(self, model):
model.milestones = [2]
model.run()

assert model.inputs is not None
assert Path('my_model/milestone_2/model_graph.bin').exists()
assert Path('my_model/milestone_2/generator_state.bin').exists()
assert Path('my_model/milestone_2/version.md').exists()
assert model.inputs["step_count"] == 2

def test_model_milestone_restart(self, model):
model.milestones = [1]
model.step_target = 3 # only run the model till step 3
model.run()
expected_generator_state = set(model.inputs["generator_state"].tolist())

Member: Can we add some checks here before restarting the model?

Author: Which checks do you think are needed?

Author: TODO: before the second run, check that the correct config file is present and that the model has been saved.

model.step_target = 5 # restart the model and run till step 5
model.run(restart=1)
stored_generator_state = set(model.inputs["generator_state"].tolist())

assert model.inputs is not None
assert model.inputs["step_count"] == 1
assert model.step_count == 5
assert stored_generator_state == expected_generator_state

Member: Tests are missing for the case isinstance(restart, tuple), e.g. model.run(restart=(2, 3)).
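A sketch of such a test, mirroring test_model_milestone_restart above (hypothetical; it assumes a second pass through step 2 produces my_model/milestone_2_1 via _make_path_unique):

def test_model_milestone_restart_tuple(self, model):
    model.milestones = [2]
    model.step_target = 3
    model.run()                # first pass saves my_model/milestone_2

    model.step_count = 0       # illustrative reset to trigger a second save
    model.run()                # second pass saves my_model/milestone_2_1

    model.step_target = 5
    model.run(restart=(2, 1))  # resume from my_model/milestone_2_1
    assert model.step_count == 5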

1 change: 1 addition & 0 deletions dgl_ptm/version.md
@@ -0,0 +1 @@
d3da61c8094ea16199e12f767b3ec8cfd1f4dae8