From 31789b640df30f4571a1f2b21993ccfb9586daa9 Mon Sep 17 00:00:00 2001
From: NikoOinonen
Date: Wed, 7 Feb 2024 13:27:29 +0200
Subject: [PATCH] Infer channels from z size in AttentionUnet + test + some fixes

---
 .gitignore             |  1 +
 mlspm/image/models.py  | 48 ++++++++++++++++++++++++++++++------------------
 mlspm/preprocessing.py |  6 +++---
 papers/ed-afm/train.py |  6 +++---
 tests/test_models.py   | 44 +++++++++++++++++++++++---------------------
 5 files changed, 60 insertions(+), 45 deletions(-)

diff --git a/.gitignore b/.gitignore
index 3c0e4df..e37d449 100644
--- a/.gitignore
+++ b/.gitignore
@@ -88,3 +88,4 @@ dmypy.json
 
 # Other
 molecules
+*.csv
diff --git a/mlspm/image/models.py b/mlspm/image/models.py
index a5c8f57..27c0d47 100644
--- a/mlspm/image/models.py
+++ b/mlspm/image/models.py
@@ -1,4 +1,3 @@
-from os import PathLike
 from typing import List, Literal, Optional, Tuple
 
 import torch
@@ -21,11 +20,10 @@ class AttentionUNet(nn.Module):
     by concatenating along channel axis.
 
     Arguments:
-        conv3d_in_channels: Number of channels in input.
-        conv2d_in_channels: Number of channels in first 2D conv layer after flattening 3D to 2D.
-        conv3d_out_channels: Number of channels after 3D-to-2D flattening after each 3D conv block. Depends on input z size.
+        z_in: Size of input array in the z-dimension.
         n_in: Number of input 3D images.
         n_out: Number of output 2D maps.
+        in_channels: Number of channels in input array.
         merge_block_channels: Number of channels in input merging 3D conv blocks.
         merge_block_depth: Number of layers in each merge conv block.
         conv3d_block_channels: Number channels in 3D conv blocks.
@@ -54,11 +52,10 @@ class AttentionUNet(nn.Module):
 
     def __init__(
         self,
-        conv3d_in_channels: int,
-        conv2d_in_channels: int,
-        conv3d_out_channels: List[int],
+        z_in: int = 10,
         n_in: int = 1,
        n_out: int = 3,
+        in_channels: int = 1,
         merge_block_channels: List[int] = [8],
         merge_block_depth: int = 2,
         conv3d_block_channels: List[int] = [8, 16, 32],
@@ -88,7 +85,6 @@ def __init__(
 
         assert (
             len(conv3d_block_channels)
-            == len(conv3d_out_channels)
             == len(conv3d_dropouts)
             == len(upscale2d_block_channels)
             == len(upscale2d_block_channels2)
@@ -119,13 +115,23 @@ def __init__(
         self.out_relus = out_relus
         self.relu_act = nn.ReLU()
 
+        # Infer number of channels after 3D-to-2D flattening at each stage from the z_in size
+        z_size = z_in
+        attention_in_channels = []
+        for pool_stride, conv3d_channels in zip(pool_z_strides, conv3d_block_channels):
+            attention_in_channels.append(conv3d_channels * z_size)
+            z_size = z_size // pool_stride
+            z_size -= max(0, 2 - pool_stride)
+        conv2d_in_channels = conv3d_block_channels[-1] * z_size
+        attention_in_channels = list(reversed(attention_in_channels))
+
         # -- Input merge conv blocks --
         self.merge_convs = nn.ModuleList([None] * n_in)
         for i in range(n_in):
             self.merge_convs[i] = nn.ModuleList(
                 [
                     Conv3dBlock(
-                        conv3d_in_channels,
+                        in_channels,
                         merge_block_channels[0],
                         3,
                         merge_block_depth,
@@ -207,7 +213,7 @@ def __init__(
         # -- Decoder conv blocks --
         self.attentions = nn.ModuleList([])
-        for c_att, c_conv in zip(attention_channels, reversed(conv3d_out_channels)):
+        for c_att, c_conv in zip(attention_channels, attention_in_channels):
             self.attentions.append(
                 UNetAttentionConv(
                     c_conv, conv2d_block_channels[-1], c_att, 3, padding_mode, self.act, attention_activation, upsample_mode="bilinear"
                 )
@@ -246,7 +252,7 @@ def __init__(
         for i in range(len(upscale2d_block_channels2)):
             self.upscale2d_blocks2.append(
                 Conv2dBlock(
-                    upscale2d_block_channels[i] + conv3d_out_channels[-(i + 1)],
+                    upscale2d_block_channels[i] + attention_in_channels[i],
                     upscale2d_block_channels2[i],
                     3,
                     upscale2d_block_depth2,
@@ -312,12 +318,19 @@ def __init__(
     def _flatten(self, x):
         return x.permute(0, 1, 4, 2, 3).reshape(x.size(0), -1, x.size(2), x.size(3))
 
-    def forward(self, x: List[torch.Tensor]):
+    def forward(self, x: List[torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         """
         Do forward computation.
 
         Arguments:
-            x: Input AFM images of shape (batch, channels, )
+            x: Input AFM images of shape (batch, channels, x, y, z).
+
+        Returns:
+            Tuple (**outputs**, **attention_maps**), where
+
+            - **outputs** - Output arrays of shape ``(batch, out_convs_channels, x, y)`` or ``(batch, x, y)`` if
+              out_convs_channels == 1.
+            - **attention_maps** - Attention maps at each stage of skip-connections.
         """
 
         assert len(x) == self.n_in
@@ -411,19 +424,18 @@ def __init__(
 
         n_in = 2
         super().__init__(
-            conv3d_in_channels=1,
-            conv2d_in_channels=192,
+            z_in=6,
+            n_in=n_in,
+            n_out=1,
+            in_channels=1,
             merge_block_channels=[32],
             merge_block_depth=2,
-            conv3d_out_channels=[288, 288, 384],
             conv3d_dropouts=[0.0, 0.0, 0.0],
             conv3d_block_channels=[48, 96, 192],
             conv3d_block_depth=3,
             conv2d_block_channels=[512],
             conv2d_block_depth=3,
             conv2d_dropouts=[0.0],
-            n_in=n_in,
-            n_out=1,
             upscale2d_block_channels=[256, 128, 64],
             upscale2d_block_depth=2,
             upscale2d_block_channels2=[256, 128, 64],
diff --git a/mlspm/preprocessing.py b/mlspm/preprocessing.py
index be75b8d..5c5be09 100644
--- a/mlspm/preprocessing.py
+++ b/mlspm/preprocessing.py
@@ -231,7 +231,7 @@ def add_rotation_reflection(
     multiple: int = 2,
     crop: Optional[Tuple[int]] = None,
     per_batch_item: bool = False,
-):
+) -> Tuple[np.ndarray, np.ndarray]:
     """
     Augment batch with random rotations and reflections.
 
@@ -295,7 +295,7 @@ def random_crop(
     max_aspect: float = 2.0,
     multiple: int = 8,
     distribution: Literal["flat", "exp-log"] = "flat",
-):
+) -> Tuple[np.ndarray, np.ndarray]:
     """
     Randomly crop images in a batch to a different size and aspect ratio.
 
@@ -309,7 +309,7 @@ def random_crop(
             between (1, max_aspect) and half of time is flipped. If 'exp-log', then distribution is exp of log of uniform
             distribution over (1/max_aspect, max_aspect). 'exp-log' is more biased towards square aspect ratios.
 
-    Returns: 
+    Returns:
         Tuple (**X**, **Y**), where
 
         - **X** - Batch of cropped AFM images of shape ``(batch, x_new, y_new, z)``.
diff --git a/papers/ed-afm/train.py b/papers/ed-afm/train.py
index a19cd5a..9f14dbe 100644
--- a/papers/ed-afm/train.py
+++ b/papers/ed-afm/train.py
@@ -315,7 +315,7 @@ def run(cfg):
             pred, _ = model(X)
 
             # Data back to host
-            X = X.squeeze(1).cpu().numpy()
+            X = [x.squeeze(1).cpu().numpy() for x in X]
             pred = [p.cpu().numpy() for p in pred]
             ref = [r.cpu().numpy() for r in ref]
 
@@ -323,7 +323,7 @@ def run(cfg):
             utils.batch_write_xyzs(xyz, outdir=pred_dir, start_ind=counter)
 
             # Visualize input AFM images and predictions
-            vis.make_input_plots([X], outdir=pred_dir, start_ind=counter)
+            vis.make_input_plots(X, outdir=pred_dir, start_ind=counter)
             vis.make_prediction_plots(pred, ref, descriptors=cfg["loss_labels"], outdir=pred_dir, start_ind=counter)
 
             counter += len(X[0])
@@ -336,7 +336,7 @@ def run(cfg):
 
 
 if __name__ == "__main__":
-    
+
     # Get config
     cfg = parse_args()
     run_dir = Path(cfg["run_dir"])
diff --git a/tests/test_models.py b/tests/test_models.py
index 4092bc6..af7acc8 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -283,27 +283,29 @@ def test_AttentionUnet():
     torch.manual_seed(0)
     device = "cpu"
 
-    model = AttentionUNet(
-        conv3d_in_channels=1,
-        conv2d_in_channels=64,
-        conv3d_out_channels=[80, 80, 128],
-        n_in=2,
-        n_out=3,
-        merge_block_channels=[8],
-        conv3d_block_channels=[8, 16, 32],
-        conv2d_block_channels=[128],
-        attention_channels= [16, 32, 24],
-        pool_z_strides=[2, 1, 2],
-        device=device
-    )
-    x = [torch.rand((5, 1, 128, 128, 10)).to(device), torch.rand((5, 1, 128, 128, 10)).to(device)]
-    ys, att = model(x)
+    for z_in in range(6, 15):
+
+        model = AttentionUNet(
+            z_in=z_in,
+            n_in=2,
+            n_out=3,
+            in_channels=1,
+            merge_block_channels=[2],
+            conv3d_block_channels=[2, 4, 8],
+            conv2d_block_channels=[32],
+            attention_channels= [4, 8, 6],
+            pool_z_strides=[2, 1, 2],
+            device=device
+        )
+        x = [torch.rand((5, 1, 128, 128, z_in)).to(device), torch.rand((5, 1, 128, 128, z_in)).to(device)]
+        ys, att = model(x)
+
+        assert len(ys) == 3
+        assert ys[0].shape == ys[1].shape == ys[2].shape == (5, 128, 128)
 
-    assert len(ys) == 3
-    assert ys[0].shape == ys[1].shape == ys[2].shape == (5, 128, 128)
 
-    assert len(att) == 3
-    assert att[0].shape == (5, 32, 32)
-    assert att[1].shape == (5, 64, 64)
-    assert att[2].shape == (5, 128, 128)
+        assert len(att) == 3
+        assert att[0].shape == (5, 32, 32)
+        assert att[1].shape == (5, 64, 64)
+        assert att[2].shape == (5, 128, 128)
 
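
Note on the inferred channel counts: the block added to AttentionUNet.__init__ derives the per-stage flattened channel counts (previously the conv3d_out_channels argument) and the 2D input channel count (previously conv2d_in_channels) from z_in and pool_z_strides. The standalone sketch below reproduces that arithmetic so the numbers can be checked by hand; the helper name infer_flattened_channels is illustrative only and is not part of the patch.

    # Illustrative sketch of the channel inference added in AttentionUNet.__init__.
    # At each encoder stage the 3D feature map with z_size slices is flattened to 2D,
    # so the flattened channel count is conv_channels * z_size. The z_size update is
    # consistent with a z-pooling kernel of size 2 with the given stride.
    def infer_flattened_channels(z_in, pool_z_strides, conv3d_block_channels):
        z_size = z_in
        attention_in_channels = []
        for pool_stride, conv3d_channels in zip(pool_z_strides, conv3d_block_channels):
            attention_in_channels.append(conv3d_channels * z_size)
            z_size = z_size // pool_stride
            z_size -= max(0, 2 - pool_stride)  # stride 1 still shrinks z by 1
        conv2d_in_channels = conv3d_block_channels[-1] * z_size
        return list(reversed(attention_in_channels)), conv2d_in_channels

    # z_in=6 with the ED-AFM settings reproduces the previously hard-coded values
    # conv3d_out_channels=[288, 288, 384] (reversed here) and conv2d_in_channels=192:
    print(infer_flattened_channels(6, [2, 1, 2], [48, 96, 192]))   # ([384, 288, 288], 192)
    # z_in=10 with the old test settings reproduces conv3d_out_channels=[80, 80, 128]
    # and conv2d_in_channels=64:
    print(infer_flattened_channels(10, [2, 1, 2], [8, 16, 32]))    # ([128, 80, 80], 64)

With pool_z_strides=[2, 1, 2] the inferred z size drops to zero for any z_in below 6, so z_in=6 is the smallest usable input depth, which is presumably why the new test sweeps range(6, 15).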