Added integration test for model training.
NikoOinonen committed Nov 27, 2023
1 parent 4c3cc8c commit 1db8682
Showing 3 changed files with 264 additions and 9 deletions.
2 changes: 1 addition & 1 deletion mlspm/data_loading.py
@@ -184,7 +184,7 @@ def get_scan_window_from_comment(comment: str) -> np.ndarray:
return sw


-def _rotate_and_stack(src: Iterable[dict], reverse: bool = True) -> Generator[dict, None, None]:
+def _rotate_and_stack(src: Iterable[dict], reverse: bool = False) -> Generator[dict, None, None]:
"""
Take a sample in dict format and update it with fields containing an image stack, xyz coordinates and scan window.
Rotate the images to be xy-indexing convention and stack them into a single array.
9 changes: 1 addition & 8 deletions papers/ice_structure_discovery/training/fit_posnet.py
@@ -240,14 +240,7 @@ def run(cfg):
lr_decay.step()

# Log losses
-            try:
-                loss_logger.add_train_loss(loss)
-            except ValueError as e:
-                torch.save(model.module.state_dict(), save_path := os.path.join(cfg['run_dir'], 'debug_model.pth'))
-                with open('debug_data', 'wb') as f:
-                    pickle.dump((X.cpu().numpy(), ref.cpu().numpy()), f)
-                print(f'Save debug data on rank {cfg["global_rank"]}')
-                raise e
+            loss_logger.add_train_loss(loss)

if cfg['timings'] and cfg['global_rank'] == 0:
torch.cuda.synchronize()
262 changes: 262 additions & 0 deletions tests/integration_tests/test_train_posnet.py
@@ -0,0 +1,262 @@
#!/usr/bin/env python3

import os
import random
import shutil
import tarfile
from functools import partial
from pathlib import Path

import numpy as np
import torch
import webdataset as wds
from torch import nn, optim

import mlspm.data_loading as dl
import mlspm.preprocessing as pp
from mlspm import graph, utils
from mlspm.cli import parse_args
from mlspm.logging import LossLogPlot
from mlspm.models import PosNet

from PIL import Image


def make_model(device, cfg):
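    # One output z-slice per box_res[2] step across z_lims, inclusive of both endpoints.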
outsize = round((cfg["z_lims"][1] - cfg["z_lims"][0]) / cfg["box_res"][2]) + 1
model = PosNet(
encode_block_channels=[2, 4, 8, 16],
encode_block_depth=2,
decode_block_channels=[16, 8, 4],
decode_block_depth=1,
decode_block_channels2=[16, 8, 4],
decode_block_depth2=1,
attention_channels=[16, 16, 16],
res_connections=True,
activation="relu",
padding_mode="zeros",
pool_type="avg",
decoder_z_sizes=[5, 10, outsize],
z_outs=[3, 3, 5, 8],
peak_std=cfg["peak_std"],
).to(device)
criterion = nn.MSELoss(reduction="mean")
optimizer = optim.Adam(model.parameters(), lr=cfg["lr"])
lr_decay_rate = 1e-5
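    # Inverse-time decay: the learning-rate multiplier 1 / (1 + rate * step) is updated once per batch.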
lr_decay = optim.lr_scheduler.LambdaLR(optimizer, lambda b: 1.0 / (1.0 + lr_decay_rate * b))
return model, criterion, optimizer, lr_decay


def make_test_data(cfg):
out_dir = Path(cfg["data_dir"])
out_dir.mkdir(exist_ok=True)
urls = wds.shardlists.expand_urls(cfg["urls_train"])
i_sample = 0
for url in urls:
temp_dir = Path(f"temp_{url}")
temp_dir.mkdir(exist_ok=True)
os.chdir(temp_dir)
with tarfile.open(url, "w") as f:
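            # Each shard packs 10 random samples: 8 grayscale AFM slices per sample saved as PNGs, plus an .xyz geometry.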
for _ in range(10):
afm = np.random.randint(0, 255, (64, 64, 8), dtype=np.uint8)
for i in range(afm.shape[-1]):
img_path = f"{i_sample}.{i}.png"
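                    # Flip y and transpose so the saved PNG is in image (row, column) order; the data loader rotates back to xy-indexing on read.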
Image.fromarray(afm[:, ::-1, i].T).save(img_path)
f.add(img_path)
xyz = np.random.rand(8, 3)
xyz[:, :2] *= 8
atoms = np.concatenate([xyz, np.random.randint(1, 10, (8, 1))], axis=1)
xyz_path = f"{i_sample}.xyz"
utils.write_to_xyz(atoms, outfile=xyz_path, comment_str="Scan window: [[0.0 0.0 0.0], [8.0 8.0 1.0]]", verbose=0)
f.add(xyz_path)
i_sample += 1
os.chdir("..")
(temp_dir / url).rename(out_dir / url)
shutil.rmtree(temp_dir)


def apply_preprocessing(batch, cfg):
box_res = cfg["box_res"]
z_lims = cfg["z_lims"]
zmin = cfg["zmin"]
peak_std = cfg["peak_std"]

X, atoms, scan_windows = [batch[k] for k in ["X", "xyz", "sw"]]

nz_max = X[0].shape[-1]
nz = random.choice(range(1, nz_max + 1))
z0 = random.choice(range(0, min(5, nz_max + 1 - nz)))
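    # Keep a random window of nz slices near the top of the stack, shifted z0 slices down; z0 == 0 needs its own case because -0 is not a valid end index.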
X = [x[:, :, :, -nz:] for x in X] if z0 == 0 else [x[:, :, :, -(nz + z0) : -z0] for x in X]

    # Filter out copper atoms (atomic number 29); the random test data uses elements 1-9, so nothing is removed here.
    atoms = [a[a[:, -1] != 29] for a in atoms]
pp.top_atom_to_zero(atoms)
xyz = atoms.copy()
mols = [graph.MoleculeGraph(a, []) for a in atoms]
mols, sw = graph.shift_mols_window(mols, scan_windows[0])

pp.rand_shift_xy_trend(X, max_layer_shift=0.02, max_total_shift=0.04)
box_borders = graph.make_box_borders(X[0].shape[1:3], res=box_res[:2], z_range=z_lims)
X, mols, box_borders = graph.add_rotation_reflection_graph(
X, mols, box_borders, num_rotations=1, reflections=True, crop=(32, 32), per_batch_item=True
)
pp.add_norm(X)
pp.add_gradient(X, c=0.3)
pp.add_noise(X, c=0.1, randomize_amplitude=True, normal_amplitude=True)
pp.add_cutout(X, n_holes=5)

mols = graph.threshold_atoms_bonds(mols, zmin)
ref = graph.make_position_distribution(mols, box_borders, box_res=box_res, std=peak_std)

return X, [ref], xyz, box_borders


def make_webDataloader(cfg):
shard_list = dl.ShardList(
        cfg["urls_train"],
base_path=cfg["data_dir"],
substitute_param=True,
log=Path(cfg["run_dir"]) / "shards.log",
)

dataset = wds.WebDataset(shard_list)
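    # Drop the last stage of the default pipeline and rebuild it explicitly below (tar iteration, worker splitting, decoding, batching).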
dataset.pipeline.pop()
dataset.append(wds.tariterators.tarfile_to_samples())
dataset.append(wds.split_by_worker)
dataset.append(wds.decode("pill", dl.decode_xyz))
dataset.append(dl.rotate_and_stack())
dataset.append(dl.batched(cfg["batch_size"]))
dataset = dataset.map(partial(apply_preprocessing, cfg=cfg))

dataloader = wds.WebLoader(
dataset,
num_workers=cfg["num_workers"],
batch_size=None,
pin_memory=True,
collate_fn=dl.default_collate,
persistent_workers=False,
)

return dataset, dataloader


def batch_to_device(batch, device):
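    # Unwrap the single-element lists produced by preprocessing and move the tensors to the target device.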
X, ref, *rest = batch
X = X[0].to(device)
ref = ref[0].to(device)
return X, ref, *rest


def run(cfg):
device = "cuda" if torch.cuda.is_available() else "cpu"

# Create run directory
if not os.path.exists(cfg["run_dir"]):
os.makedirs(cfg["run_dir"])

# Define model, optimizer, and loss
model, criterion, optimizer, lr_decay = make_model(device, cfg)

# Setup checkpointing and load a checkpoint if available
checkpointer = utils.Checkpointer(
model,
optimizer,
additional_data={"lr_params": lr_decay},
checkpoint_dir=os.path.join(cfg["run_dir"], "Checkpoints/"),
keep_last_epoch=True,
)
init_epoch = checkpointer.epoch

# Setup logging
log_file = open(os.path.join(cfg["run_dir"], "batches.log"), "a")
loss_logger = LossLogPlot(
log_path=os.path.join(cfg["run_dir"], "loss_log.csv"),
plot_path=os.path.join(cfg["run_dir"], "loss_history.png"),
loss_labels=cfg["loss_labels"],
loss_weights=cfg["loss_weights"],
print_interval=cfg["print_interval"],
init_epoch=init_epoch,
stream=log_file,
)

for epoch in range(cfg["epochs"]):
# Create datasets and dataloaders
_, train_loader = make_webDataloader(cfg)
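        # Reuse the training loader for validation; this integration test only checks that the loop runs end to end.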
val_loader = train_loader

print(f"\n === Epoch {epoch}")

model.train()
for ib, batch in enumerate(train_loader):
# Transfer batch to device
X, ref, _, _ = batch_to_device(batch, device)

# Forward
pred, _ = model(X)
loss = criterion(pred, ref)

# Backward
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
lr_decay.step()

# Log losses
loss_logger.add_train_loss(loss)

print(f"Train batch {ib}")

# Validate

model.eval()
with torch.no_grad():
for ib, batch in enumerate(val_loader):
# Transfer batch to device
X, ref, _, _ = batch_to_device(batch, device)

# Forward
pred, _ = model(X)
loss = criterion(pred, ref)

loss_logger.add_val_loss(loss)

print(f"Val batch {ib}")

# Write average losses to log and report to terminal
loss_logger.next_epoch()

# Save checkpoint
checkpointer.next_epoch(loss_logger.val_losses[-1][0])

# Return to best epoch, and save model weights
checkpointer.revert_to_best_epoch()
print(f"Best validation loss on epoch {checkpointer.best_epoch}: {checkpointer.best_loss}")

log_file.close()
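    # Clean up the run directory and the generated test data.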
shutil.rmtree(cfg["run_dir"])
shutil.rmtree(cfg["data_dir"])


def test_train_posnet():
# fmt:off
cfg = parse_args(
[
"--run_dir", "test_train",
"--epochs", "2",
"--batch_size", "4",
"--z_lims", "-1.0", "0.5",
"--zmin", "-1.0",
"--data_dir", "./test_data",
"--urls_train", "train-K-{0..1}_{0..1}.tar",
"--box_res", "0.125", "0.125", "0.100",
"--peak_std", "0.20",
"--lr", "1e-4"
]
)
# fmt:on

make_test_data(cfg)
run(cfg)


if __name__ == "__main__":
test_train_posnet()
