Infer channels from z size in AttentionUnet + test + some fixes
NikoOinonen committed Feb 7, 2024
1 parent 5d0a2f0 commit 31789b6
Showing 5 changed files with 60 additions and 45 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -88,3 +88,4 @@ dmypy.json

# Other
molecules
*.csv
48 changes: 30 additions & 18 deletions mlspm/image/models.py
@@ -1,4 +1,3 @@
from os import PathLike
from typing import List, Literal, Optional, Tuple

import torch
@@ -21,11 +20,10 @@ class AttentionUNet(nn.Module):
by concatenating along channel axis.
Arguments:
conv3d_in_channels: Number of channels in input.
conv2d_in_channels: Number of channels in first 2D conv layer after flattening 3D to 2D.
conv3d_out_channels: Number of channels after 3D-to-2D flattening after each 3D conv block. Depends on input z size.
z_in: Size of input array in the z-dimension.
n_in: Number of input 3D images.
n_out: Number of output 2D maps.
in_channels: Number of channels in input array.
merge_block_channels: Number of channels in input merging 3D conv blocks.
merge_block_depth: Number of layers in each merge conv block.
conv3d_block_channels: Number of channels in 3D conv blocks.
@@ -54,11 +52,10 @@ class AttentionUNet(nn.Module):

def __init__(
self,
conv3d_in_channels: int,
conv2d_in_channels: int,
conv3d_out_channels: List[int],
z_in: int = 10,
n_in: int = 1,
n_out: int = 3,
in_channels: int = 1,
merge_block_channels: List[int] = [8],
merge_block_depth: int = 2,
conv3d_block_channels: List[int] = [8, 16, 32],
@@ -88,7 +85,6 @@ def __init__(

assert (
len(conv3d_block_channels)
== len(conv3d_out_channels)
== len(conv3d_dropouts)
== len(upscale2d_block_channels)
== len(upscale2d_block_channels2)
@@ -119,13 +115,23 @@ def __init__(
self.out_relus = out_relus
self.relu_act = nn.ReLU()

# Infer number of channels after 3D-to-2D flattening at each stage from the z_in size
z_size = z_in
attention_in_channels = []
for pool_stride, conv3d_channels in zip(pool_z_strides, conv3d_block_channels):
attention_in_channels.append(conv3d_channels * z_size)
z_size = z_size // pool_stride
z_size -= max(0, 2 - pool_stride)
conv2d_in_channels = conv3d_block_channels[-1] * z_size
attention_in_channels = list(reversed(attention_in_channels))
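As a worked example, here is a standalone sketch (not code from the changed file) of what this inference yields for the configuration used in the updated test below, i.e. `z_in=10`, `conv3d_block_channels=[8, 16, 32]`, `pool_z_strides=[2, 1, 2]`:

```python
# Standalone sketch of the channel-inference loop above for z_in=10,
# conv3d_block_channels=[8, 16, 32] and pool_z_strides=[2, 1, 2].
z_size = 10
attention_in_channels = []
for pool_stride, conv3d_channels in zip([2, 1, 2], [8, 16, 32]):
    attention_in_channels.append(conv3d_channels * z_size)  # channels after z is flattened into the channel axis
    z_size = z_size // pool_stride     # pooling shrinks z by the stride
    z_size -= max(0, 2 - pool_stride)  # stride-1 pooling still shortens z by one (kernel size 2)

conv2d_in_channels = 32 * z_size  # channels entering the first 2D conv block

print(attention_in_channels)  # [80, 80, 128]
print(conv2d_in_channels)     # 64
```

These are exactly the `conv3d_out_channels=[80, 80, 128]` and `conv2d_in_channels=64` that the old test had to pass in by hand, and, assuming the same pool strides, the same arithmetic with `z_in=6` and `conv3d_block_channels=[48, 96, 192]` reproduces the `[288, 288, 384]` and `192` removed from the subclass constructor further down in this file.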

# -- Input merge conv blocks --
self.merge_convs = nn.ModuleList([None] * n_in)
for i in range(n_in):
self.merge_convs[i] = nn.ModuleList(
[
Conv3dBlock(
conv3d_in_channels,
in_channels,
merge_block_channels[0],
3,
merge_block_depth,
@@ -207,7 +213,7 @@ def __init__(

# -- Decoder conv blocks --
self.attentions = nn.ModuleList([])
for c_att, c_conv in zip(attention_channels, reversed(conv3d_out_channels)):
for c_att, c_conv in zip(attention_channels, attention_in_channels):
self.attentions.append(
UNetAttentionConv(
c_conv, conv2d_block_channels[-1], c_att, 3, padding_mode, self.act, attention_activation, upsample_mode="bilinear"
@@ -246,7 +252,7 @@ def __init__(
for i in range(len(upscale2d_block_channels2)):
self.upscale2d_blocks2.append(
Conv2dBlock(
upscale2d_block_channels[i] + conv3d_out_channels[-(i + 1)],
upscale2d_block_channels[i] + attention_in_channels[i],
upscale2d_block_channels2[i],
3,
upscale2d_block_depth2,
@@ -312,12 +318,19 @@ def __init__(
def _flatten(self, x):
return x.permute(0, 1, 4, 2, 3).reshape(x.size(0), -1, x.size(2), x.size(3))
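As a quick illustration of the flattening above (a hypothetical standalone snippet with example shapes):

```python
import torch

x = torch.rand(5, 8, 128, 128, 10)  # (batch, channels, x, y, z)
flat = x.permute(0, 1, 4, 2, 3).reshape(x.size(0), -1, x.size(2), x.size(3))
print(flat.shape)  # torch.Size([5, 80, 128, 128]); z is folded into the channel axis
```

This is why the per-stage 2D channel counts inferred in `__init__` are `conv3d_channels * z_size`.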

def forward(self, x: List[torch.Tensor]):
def forward(self, x: List[torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""
Do forward computation.
Arguments:
x: Input AFM images of shape (batch, channels, )
x: Input AFM images of shape (batch, channels, x, y, z).
Returns:
Tuple (**outputs**, **attention_maps**), where
- **outputs** - Output arrays of shape ``(batch, out_convs_channels, x, y)`` or ``(batch, x, y)`` if
out_convs_channels == 1.
- **attention_maps** - Attention maps at each stage of skip-connections.
"""
assert len(x) == self.n_in
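A minimal calling sketch (the hyperparameters mirror the updated test at the bottom of this commit; anything not listed is assumed to keep its default value):

```python
import torch

model = AttentionUNet(
    z_in=10, n_in=2, n_out=3, in_channels=1,
    merge_block_channels=[2], conv3d_block_channels=[2, 4, 8],
    conv2d_block_channels=[32], attention_channels=[4, 8, 6],
    pool_z_strides=[2, 1, 2], device="cpu",
)
x = [torch.rand(2, 1, 128, 128, 10) for _ in range(2)]  # one tensor per input branch (n_in=2)
outputs, attention_maps = model(x)  # here: 3 output maps, each of shape (2, 128, 128)
```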

@@ -411,19 +424,18 @@ def __init__(
n_in = 2

super().__init__(
conv3d_in_channels=1,
conv2d_in_channels=192,
z_in=6,
n_in=n_in,
n_out=1,
in_channels=1,
merge_block_channels=[32],
merge_block_depth=2,
conv3d_out_channels=[288, 288, 384],
conv3d_dropouts=[0.0, 0.0, 0.0],
conv3d_block_channels=[48, 96, 192],
conv3d_block_depth=3,
conv2d_block_channels=[512],
conv2d_block_depth=3,
conv2d_dropouts=[0.0],
n_in=n_in,
n_out=1,
upscale2d_block_channels=[256, 128, 64],
upscale2d_block_depth=2,
upscale2d_block_channels2=[256, 128, 64],
6 changes: 3 additions & 3 deletions mlspm/preprocessing.py
@@ -231,7 +231,7 @@ def add_rotation_reflection(
multiple: int = 2,
crop: Optional[Tuple[int]] = None,
per_batch_item: bool = False,
):
) -> Tuple[np.ndarray, np.ndarray]:
"""
Augment batch with random rotations and reflections.
@@ -295,7 +295,7 @@ def random_crop(
max_aspect: float = 2.0,
multiple: int = 8,
distribution: Literal["flat", "exp-log"] = "flat",
):
) -> Tuple[np.ndarray, np.ndarray]:
"""
Randomly crop images in a batch to a different size and aspect ratio.
@@ -309,7 +309,7 @@
between (1, max_aspect) and half of time is flipped. If 'exp-log', then distribution is exp of log of uniform
distribution over (1/max_aspect, max_aspect). 'exp-log' is more biased towards square aspect ratios.
Returns:
Returns:
Tuple (**X**, **Y**), where
- **X** - Batch of cropped AFM images of shape ``(batch, x_new, y_new, z)``.
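For context, a small standalone sketch (not code from this file; uses NumPy's default RNG) of the two aspect-ratio distributions described above:

```python
import numpy as np

rng = np.random.default_rng(0)
max_aspect = 2.0

# 'flat': uniform over (1, max_aspect), inverted half of the time
aspect = rng.uniform(1.0, max_aspect)
if rng.random() < 0.5:
    aspect = 1.0 / aspect

# 'exp-log': exp of a uniform sample over the log-range, i.e. log-uniform over
# (1/max_aspect, max_aspect); more biased towards square aspect ratios
aspect_exp_log = np.exp(rng.uniform(np.log(1.0 / max_aspect), np.log(max_aspect)))
```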
6 changes: 3 additions & 3 deletions papers/ed-afm/train.py
@@ -315,15 +315,15 @@ def run(cfg):
pred, _ = model(X)

# Data back to host
X = X.squeeze(1).cpu().numpy()
X = [x.squeeze(1).cpu().numpy() for x in X]
pred = [p.cpu().numpy() for p in pred]
ref = [r.cpu().numpy() for r in ref]

# Save xyzs
utils.batch_write_xyzs(xyz, outdir=pred_dir, start_ind=counter)

# Visualize input AFM images and predictions
vis.make_input_plots([X], outdir=pred_dir, start_ind=counter)
vis.make_input_plots(X, outdir=pred_dir, start_ind=counter)
vis.make_prediction_plots(pred, ref, descriptors=cfg["loss_labels"], outdir=pred_dir, start_ind=counter)

counter += len(X[0])
@@ -336,7 +336,7 @@ def run(cfg):


if __name__ == "__main__":

# Get config
cfg = parse_args()
run_dir = Path(cfg["run_dir"])
44 changes: 23 additions & 21 deletions tests/test_models.py
@@ -283,27 +283,29 @@ def test_AttentionUnet():
torch.manual_seed(0)

device = "cpu"
model = AttentionUNet(
conv3d_in_channels=1,
conv2d_in_channels=64,
conv3d_out_channels=[80, 80, 128],
n_in=2,
n_out=3,
merge_block_channels=[8],
conv3d_block_channels=[8, 16, 32],
conv2d_block_channels=[128],
attention_channels= [16, 32, 24],
pool_z_strides=[2, 1, 2],
device=device
)

x = [torch.rand((5, 1, 128, 128, 10)).to(device), torch.rand((5, 1, 128, 128, 10)).to(device)]
ys, att = model(x)
for z_in in range(6, 15):

model = AttentionUNet(
z_in=z_in,
n_in=2,
n_out=3,
in_channels=1,
merge_block_channels=[2],
conv3d_block_channels=[2, 4, 8],
conv2d_block_channels=[32],
attention_channels= [4, 8, 6],
pool_z_strides=[2, 1, 2],
device=device
)

assert len(ys) == 3
assert ys[0].shape == ys[1].shape == ys[2].shape == (5, 128, 128)
x = [torch.rand((5, 1, 128, 128, z_in)).to(device), torch.rand((5, 1, 128, 128, z_in)).to(device)]
ys, att = model(x)

assert len(ys) == 3
assert ys[0].shape == ys[1].shape == ys[2].shape == (5, 128, 128)

assert len(att) == 3
assert att[0].shape == (5, 32, 32)
assert att[1].shape == (5, 64, 64)
assert att[2].shape == (5, 128, 128)
assert len(att) == 3
assert att[0].shape == (5, 32, 32)
assert att[1].shape == (5, 64, 64)
assert att[2].shape == (5, 128, 128)
