Skip to content

Commit

Permalink
Added prediction script for camphor
Browse files Browse the repository at this point in the history
  • Loading branch information
NikoOinonen committed Jan 24, 2024
1 parent ae1e832 commit 4d4b67a
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 0 deletions.
117 changes: 117 additions & 0 deletions mlspm/image/_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from os import PathLike
from pathlib import Path
from typing import List

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm


def make_prediction_plots(
preds: List[np.ndarray] = None,
true: List[np.ndarray] = None,
losses: np.ndarray = None,
descriptors: List[str] = None,
outdir: PathLike = "./predictions/",
start_ind: int = 0,
verbose: bool = True,
):
"""
Plot predictions/references for image descriptors.
Arguments:
preds: Predicted maps. Each list element corresponds to one descriptor and is an array of shape (batch_size, x_dim, y_dim).
true: Reference maps. Each list element corresponds to one descriptor and is an array of shape (batch_size, x_dim, y_dim).
losses: Losses for each prediction. Array of shape (len(preds), batch_size).
descriptors: Names of descriptors. The name "ES" causes the coolwarm colormap to be used.
outdir: Directory where images are saved.
start_ind: Starting index for saved images.
verbose: Whether to print output information.
"""

rows = (preds is not None) + (true is not None)
if rows == 0:
raise ValueError("preds and true cannot both be None.")
elif rows == 1:
data = preds if preds is not None else true
else:
assert len(preds) == len(true)

cols = len(preds) if preds is not None else len(true)
if descriptors is not None:
assert len(descriptors) == cols

outdir = Path(outdir)
outdir.mkdir(exist_ok=True, parents=True)

img_ind = start_ind
batch_size = len(preds[0]) if preds is not None else len(true[0])

for j in range(batch_size):
fig, axes = plt.subplots(rows, cols)
fig.set_size_inches(6 * cols, 5 * rows)

if rows == 1:
axes = np.expand_dims(axes, axis=0)
if cols == 1:
axes = np.expand_dims(axes, axis=1)

for i in range(cols):
top_ax = axes[0, i]
bottom_ax = axes[-1, i]

if rows == 2:
p = preds[i][j]
t = true[i][j]
vmax = np.concatenate([p, t]).max()
vmin = np.concatenate([p, t]).min()
else:
d = data[i][j]
vmax = d.max()
vmin = d.min()

title1 = ""
title2 = ""
cmap = cm.viridis
if descriptors is not None:
descriptor = descriptors[i]
title1 += f"{descriptor} Prediction"
title2 += f"{descriptor} Reference"
if descriptor == "ES":
vmax = max(abs(vmax), abs(vmin))
vmin = -vmax
cmap = cm.coolwarm
if losses is not None:
title1 += f"\nMSE = {losses[i,j]:.2E}"
if vmax == vmin == 0:
vmin = 0
vmax = 0.1
if rows == 2:
im1 = top_ax.imshow(p.T, vmax=vmax, vmin=vmin, cmap=cmap, origin="lower")
im2 = bottom_ax.imshow(t.T, vmax=vmax, vmin=vmin, cmap=cmap, origin="lower")
if title1:
top_ax.set_title(title1)
bottom_ax.set_title(title2)
else:
im1 = top_ax.imshow(d.T, vmax=vmax, vmin=vmin, cmap=cmap, origin="lower")
if title1:
title = title1 if preds is not None else title2
top_ax.set_title(title)

for axi in axes[:, i]:
pos = axi.get_position()
pos_new = [pos.x0, pos.y0, 0.8 * (pos.x1 - pos.x0), pos.y1 - pos.y0]
axi.set_position(pos_new)

pos1 = top_ax.get_position()
pos2 = bottom_ax.get_position()
c_pos = [pos1.x1 + 0.1 * (pos1.x1 - pos1.x0), pos2.y0, 0.08 * (pos1.x1 - pos1.x0), pos1.y1 - pos2.y0]
cbar_ax = fig.add_axes(c_pos)
fig.colorbar(im1, cax=cbar_ax)

save_name = outdir / f"{img_ind}_pred.png"
plt.savefig(save_name, bbox_inches="tight")
plt.close()

if verbose > 0:
print(f"Prediction saved to {save_name}")
img_ind += 1
98 changes: 98 additions & 0 deletions papers/asd-afm/predict_1S-camphor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#!/usr/bin/env python3

from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import torch
from mlspm.datasets import download_dataset

from mlspm.models import ASDAFMNet
from mlspm.visualization import make_input_plots, make_prediction_plots

def plot_experiment_AFM(list_of_exp):
cols = 10
rows = len(list_of_exp)
fig = plt.figure(figsize=(1.8 * cols, 1.8 * rows))
ax = []
for i, experiment in enumerate(list_of_exp):
filename_exp = data_dir / "experimental_configs_data/" / f"{experiment}orient_exp.npz"
data = np.load(filename_exp)
X_exp = data["X"]
print("#" + str(experiment) + " 1S-Champhor experiment " + "X.shape:", X_exp.shape)

for j in range(10):
ax.append(fig.add_subplot(rows, cols, i * cols + j + 1))
xj = X_exp[0, :, :, j]
vmax = xj.max()
vmin = xj.min()
plt.imshow(xj, cmap="afmhot", origin="lower", vmin=vmin - 0.1 * (vmax - vmin), vmax=vmax + 0.1 * (vmax - vmin))
plt.xticks([])
plt.yticks([])
if j == 0:
ax[-1].set_ylabel("AFM experiment " + str(experiment))

plt.show()



def plot_experiment_preds(list_of_exp):
cols = len(list_of_exp)
rows = 1
fig = plt.figure(figsize=(2 * cols, 2 * rows))
ax = []
for i, experiment in enumerate(list_of_exp):
filename_exp = data_dir / "experimental_configs_data/" / f"{experiment}orient_exp.npz"
data = np.load(filename_exp)
Y_exp = data["Y"]
ax.append(fig.add_subplot(rows, cols, i + 1))
plt.imshow(Y_exp[1][0], origin="lower")
ax[-1].set_ylabel("vdW-Spheres")
ax[-1].set_xlabel("AFM experiment " + str(experiment))
plt.xticks([])
plt.yticks([])
plt.show()


if __name__ == '__main__':

# Input data path
data_dir = Path("afm_camphor")

# Output directory
outdir = Path('predictions')
outdir.mkdir(exist_ok=True)

# Type of pretrained weights to load ('light' or 'heavy')
weights = 'heavy'

# Device to run inference on. Set to 'cuda' to use GPU-acceleration.
device = 'cpu'

# Descriptor labels for plotting
descriptors = ['Atomic Disks', 'vdW Spheres', 'Height Map']

# Load model with pretrained weights
model = ASDAFMNet(pretrained_weights=f'asdafm-{weights}').to(device)

# Download AFM data
download_dataset('AFM-camphor-exp', data_dir)

for exp_num in [1, 3, 4, 6, 7]:

# Load data
X = np.load(data_dir / f'{exp_num}.npy')
X = torch.from_numpy(X).float().unsqueeze(1).to(device)

# Run prediction
with torch.no_grad():
pred = model(X)
pred = [p.cpu().numpy() for p in pred]

# The input data here is saved in a transposed form (y, x). Transpose it to x, y order so that the plotting
# utils work correctly.
pred = [p.transpose(0, 2, 1) for p in pred]
X = [x.squeeze(1).cpu().numpy().transpose(0, 2, 1, 3) for x in X]

# Plot the inputs (AFM) and predictions (descriptors)
make_input_plots(X, outdir=outdir, start_ind=exp_num)
make_prediction_plots(pred, descriptors=descriptors, outdir=outdir, start_ind=exp_num)

0 comments on commit 4d4b67a

Please sign in to comment.