-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ae1e832
commit 4d4b67a
Showing
2 changed files
with
215 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
from os import PathLike | ||
from pathlib import Path | ||
from typing import List | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from matplotlib import cm | ||
|
||
|
||
def make_prediction_plots( | ||
preds: List[np.ndarray] = None, | ||
true: List[np.ndarray] = None, | ||
losses: np.ndarray = None, | ||
descriptors: List[str] = None, | ||
outdir: PathLike = "./predictions/", | ||
start_ind: int = 0, | ||
verbose: bool = True, | ||
): | ||
""" | ||
Plot predictions/references for image descriptors. | ||
Arguments: | ||
preds: Predicted maps. Each list element corresponds to one descriptor and is an array of shape (batch_size, x_dim, y_dim). | ||
true: Reference maps. Each list element corresponds to one descriptor and is an array of shape (batch_size, x_dim, y_dim). | ||
losses: Losses for each prediction. Array of shape (len(preds), batch_size). | ||
descriptors: Names of descriptors. The name "ES" causes the coolwarm colormap to be used. | ||
outdir: Directory where images are saved. | ||
start_ind: Starting index for saved images. | ||
verbose: Whether to print output information. | ||
""" | ||
|
||
rows = (preds is not None) + (true is not None) | ||
if rows == 0: | ||
raise ValueError("preds and true cannot both be None.") | ||
elif rows == 1: | ||
data = preds if preds is not None else true | ||
else: | ||
assert len(preds) == len(true) | ||
|
||
cols = len(preds) if preds is not None else len(true) | ||
if descriptors is not None: | ||
assert len(descriptors) == cols | ||
|
||
outdir = Path(outdir) | ||
outdir.mkdir(exist_ok=True, parents=True) | ||
|
||
img_ind = start_ind | ||
batch_size = len(preds[0]) if preds is not None else len(true[0]) | ||
|
||
for j in range(batch_size): | ||
fig, axes = plt.subplots(rows, cols) | ||
fig.set_size_inches(6 * cols, 5 * rows) | ||
|
||
if rows == 1: | ||
axes = np.expand_dims(axes, axis=0) | ||
if cols == 1: | ||
axes = np.expand_dims(axes, axis=1) | ||
|
||
for i in range(cols): | ||
top_ax = axes[0, i] | ||
bottom_ax = axes[-1, i] | ||
|
||
if rows == 2: | ||
p = preds[i][j] | ||
t = true[i][j] | ||
vmax = np.concatenate([p, t]).max() | ||
vmin = np.concatenate([p, t]).min() | ||
else: | ||
d = data[i][j] | ||
vmax = d.max() | ||
vmin = d.min() | ||
|
||
title1 = "" | ||
title2 = "" | ||
cmap = cm.viridis | ||
if descriptors is not None: | ||
descriptor = descriptors[i] | ||
title1 += f"{descriptor} Prediction" | ||
title2 += f"{descriptor} Reference" | ||
if descriptor == "ES": | ||
vmax = max(abs(vmax), abs(vmin)) | ||
vmin = -vmax | ||
cmap = cm.coolwarm | ||
if losses is not None: | ||
title1 += f"\nMSE = {losses[i,j]:.2E}" | ||
if vmax == vmin == 0: | ||
vmin = 0 | ||
vmax = 0.1 | ||
if rows == 2: | ||
im1 = top_ax.imshow(p.T, vmax=vmax, vmin=vmin, cmap=cmap, origin="lower") | ||
im2 = bottom_ax.imshow(t.T, vmax=vmax, vmin=vmin, cmap=cmap, origin="lower") | ||
if title1: | ||
top_ax.set_title(title1) | ||
bottom_ax.set_title(title2) | ||
else: | ||
im1 = top_ax.imshow(d.T, vmax=vmax, vmin=vmin, cmap=cmap, origin="lower") | ||
if title1: | ||
title = title1 if preds is not None else title2 | ||
top_ax.set_title(title) | ||
|
||
for axi in axes[:, i]: | ||
pos = axi.get_position() | ||
pos_new = [pos.x0, pos.y0, 0.8 * (pos.x1 - pos.x0), pos.y1 - pos.y0] | ||
axi.set_position(pos_new) | ||
|
||
pos1 = top_ax.get_position() | ||
pos2 = bottom_ax.get_position() | ||
c_pos = [pos1.x1 + 0.1 * (pos1.x1 - pos1.x0), pos2.y0, 0.08 * (pos1.x1 - pos1.x0), pos1.y1 - pos2.y0] | ||
cbar_ax = fig.add_axes(c_pos) | ||
fig.colorbar(im1, cax=cbar_ax) | ||
|
||
save_name = outdir / f"{img_ind}_pred.png" | ||
plt.savefig(save_name, bbox_inches="tight") | ||
plt.close() | ||
|
||
if verbose > 0: | ||
print(f"Prediction saved to {save_name}") | ||
img_ind += 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
#!/usr/bin/env python3 | ||
|
||
from pathlib import Path | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import torch | ||
from mlspm.datasets import download_dataset | ||
|
||
from mlspm.models import ASDAFMNet | ||
from mlspm.visualization import make_input_plots, make_prediction_plots | ||
|
||
def plot_experiment_AFM(list_of_exp): | ||
cols = 10 | ||
rows = len(list_of_exp) | ||
fig = plt.figure(figsize=(1.8 * cols, 1.8 * rows)) | ||
ax = [] | ||
for i, experiment in enumerate(list_of_exp): | ||
filename_exp = data_dir / "experimental_configs_data/" / f"{experiment}orient_exp.npz" | ||
data = np.load(filename_exp) | ||
X_exp = data["X"] | ||
print("#" + str(experiment) + " 1S-Champhor experiment " + "X.shape:", X_exp.shape) | ||
|
||
for j in range(10): | ||
ax.append(fig.add_subplot(rows, cols, i * cols + j + 1)) | ||
xj = X_exp[0, :, :, j] | ||
vmax = xj.max() | ||
vmin = xj.min() | ||
plt.imshow(xj, cmap="afmhot", origin="lower", vmin=vmin - 0.1 * (vmax - vmin), vmax=vmax + 0.1 * (vmax - vmin)) | ||
plt.xticks([]) | ||
plt.yticks([]) | ||
if j == 0: | ||
ax[-1].set_ylabel("AFM experiment " + str(experiment)) | ||
|
||
plt.show() | ||
|
||
|
||
|
||
def plot_experiment_preds(list_of_exp): | ||
cols = len(list_of_exp) | ||
rows = 1 | ||
fig = plt.figure(figsize=(2 * cols, 2 * rows)) | ||
ax = [] | ||
for i, experiment in enumerate(list_of_exp): | ||
filename_exp = data_dir / "experimental_configs_data/" / f"{experiment}orient_exp.npz" | ||
data = np.load(filename_exp) | ||
Y_exp = data["Y"] | ||
ax.append(fig.add_subplot(rows, cols, i + 1)) | ||
plt.imshow(Y_exp[1][0], origin="lower") | ||
ax[-1].set_ylabel("vdW-Spheres") | ||
ax[-1].set_xlabel("AFM experiment " + str(experiment)) | ||
plt.xticks([]) | ||
plt.yticks([]) | ||
plt.show() | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
# Input data path | ||
data_dir = Path("afm_camphor") | ||
|
||
# Output directory | ||
outdir = Path('predictions') | ||
outdir.mkdir(exist_ok=True) | ||
|
||
# Type of pretrained weights to load ('light' or 'heavy') | ||
weights = 'heavy' | ||
|
||
# Device to run inference on. Set to 'cuda' to use GPU-acceleration. | ||
device = 'cpu' | ||
|
||
# Descriptor labels for plotting | ||
descriptors = ['Atomic Disks', 'vdW Spheres', 'Height Map'] | ||
|
||
# Load model with pretrained weights | ||
model = ASDAFMNet(pretrained_weights=f'asdafm-{weights}').to(device) | ||
|
||
# Download AFM data | ||
download_dataset('AFM-camphor-exp', data_dir) | ||
|
||
for exp_num in [1, 3, 4, 6, 7]: | ||
|
||
# Load data | ||
X = np.load(data_dir / f'{exp_num}.npy') | ||
X = torch.from_numpy(X).float().unsqueeze(1).to(device) | ||
|
||
# Run prediction | ||
with torch.no_grad(): | ||
pred = model(X) | ||
pred = [p.cpu().numpy() for p in pred] | ||
|
||
# The input data here is saved in a transposed form (y, x). Transpose it to x, y order so that the plotting | ||
# utils work correctly. | ||
pred = [p.transpose(0, 2, 1) for p in pred] | ||
X = [x.squeeze(1).cpu().numpy().transpose(0, 2, 1, 3) for x in X] | ||
|
||
# Plot the inputs (AFM) and predictions (descriptors) | ||
make_input_plots(X, outdir=outdir, start_ind=exp_num) | ||
make_prediction_plots(pred, descriptors=descriptors, outdir=outdir, start_ind=exp_num) |