From fac368ed4434de439cc40536b84d32aa69d8efa2 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Sat, 5 Oct 2024 12:12:09 +0200 Subject: [PATCH 1/4] Add PUNet logger --- torch_em/self_training/logger.py | 29 +++++++++++++++++++ .../probabilistic_unet_trainer.py | 13 +++++---- 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/torch_em/self_training/logger.py b/torch_em/self_training/logger.py index 3ecbfa9d..88b50983 100644 --- a/torch_em/self_training/logger.py +++ b/torch_em/self_training/logger.py @@ -82,3 +82,32 @@ def log_validation(self, step, metric, loss, xt, xt1, xt2, y, z, gt, samples, gt self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step) if gt_metric is not None: self.tb.add_scalar(tag="validation/gt_metric", scalar_value=gt_metric, global_step=step) + + +class ProbabilisticUNetTrainerLogger(torch_em.trainer.logger_base.TorchEmLogger): + def __init__(self, trainer, save_root, **unused_kwargs): + super().__init__(trainer, save_root) + self.log_dir = f"./logs/{trainer.name}" if save_root is None else\ + os.path.join(save_root, "logs", trainer.name) + os.makedirs(self.log_dir, exist_ok=True) + + self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir) + self.log_image_interval = trainer.log_image_interval + + def add_image(self, x, y, samples, name, step): + # NOTE: we only show the first tensor per batch for all images + self.tb.add_image(tag=f"{name}/input", img_tensor=x[0], global_step=step) + self.tb.add_image(tag=f"{name}/target", img_tensor=y[0], global_step=step) + sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4) + self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step) + + def log_train(self, step, loss, lr, x, y, samples): + self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step) + self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step) + if step % self.log_image_interval == 0: + self.add_image(x, y, samples, "train", step) + + def log_validation(self, step, metric, loss, x, y, samples): + self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step) + self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step) + self.add_image(x, y, samples, "validation", step) diff --git a/torch_em/self_training/probabilistic_unet_trainer.py b/torch_em/self_training/probabilistic_unet_trainer.py index 2af1f1e7..85f98476 100644 --- a/torch_em/self_training/probabilistic_unet_trainer.py +++ b/torch_em/self_training/probabilistic_unet_trainer.py @@ -21,12 +21,12 @@ class ProbabilisticUNetTrainer(torch_em.trainer.DefaultTrainer): """ def __init__( - self, - clipping_value=None, - prior_samples=16, - loss=None, - loss_and_metric=None, - **kwargs + self, + clipping_value=None, + prior_samples=16, + loss=None, + loss_and_metric=None, + **kwargs ): super().__init__(loss=loss, metric=DummyLoss(), **kwargs) assert loss, loss_and_metric is not None @@ -76,6 +76,7 @@ def _train_epoch_impl(self, progress, forward_context, backprop): if self.logger is not None: lr = [pm["lr"] for pm in self.optimizer.param_groups][0] + # We only sample if we log images in this iteration. samples = self._sample() if self._iteration % self.log_image_interval == 0 else None self.logger.log_train(self._iteration, loss, lr, x, y, samples) From 56cab447dce5faa708af3742eb4b0bc4c18a5613 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Sat, 5 Oct 2024 13:58:04 +0200 Subject: [PATCH 2/4] Add logger to submodule calls --- torch_em/self_training/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_em/self_training/__init__.py b/torch_em/self_training/__init__.py index 89227a0d..d96e8ef4 100644 --- a/torch_em/self_training/__init__.py +++ b/torch_em/self_training/__init__.py @@ -1,4 +1,4 @@ -from .logger import SelfTrainingTensorboardLogger +from .logger import SelfTrainingTensorboardLogger, ProbabilisticUNetTrainerLogger from .loss import DefaultSelfTrainingLoss, DefaultSelfTrainingLossAndMetric, ProbabilisticUNetLoss, \ ProbabilisticUNetLossAndMetric from .mean_teacher import MeanTeacherTrainer From fd4e7ccb11abef4832e190b26d92595b859921b0 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Sat, 5 Oct 2024 18:01:01 +0200 Subject: [PATCH 3/4] Fix linting issues --- torch_em/model/probabilistic_unet.py | 84 ++++++++++++++-------------- 1 file changed, 43 insertions(+), 41 deletions(-) diff --git a/torch_em/model/probabilistic_unet.py b/torch_em/model/probabilistic_unet.py index 70084f88..280d0452 100644 --- a/torch_em/model/probabilistic_unet.py +++ b/torch_em/model/probabilistic_unet.py @@ -1,6 +1,8 @@ # This code is based on the original TensorFlow implementation: https://github.com/SimonKohl/probabilistic_unet # The below implementation is from: https://github.com/stefanknegt/Probabilistic-Unet-Pytorch +from typing import Union + import numpy as np import torch @@ -21,7 +23,7 @@ def truncated_normal_(tensor, mean=0, std=1): def init_weights(m): - if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d: + if isinstance(m, Union[nn.Conv2d, nn.ConvTranspose2d]): nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') # nn.init.normal_(m.weight, std=0.001) # nn.init.normal_(m.bias, std=0.001) @@ -29,7 +31,7 @@ def init_weights(m): def init_weights_orthogonal_normal(m): - if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d: + if isinstance(m, Union[nn.Conv2d, nn.ConvTranspose2d]): nn.init.orthogonal_(m.weight) truncated_normal_(m.bias, mean=0, std=0.001) # nn.init.normal_(m.bias, std=0.001) @@ -126,13 +128,13 @@ def __init__( self.name = 'Prior' self.encoder = Encoder( - self.input_channels, - self.num_filters, - self.no_convs_per_block, - initializers, - posterior=self.posterior, - num_classes=num_classes - ) + self.input_channels, + self.num_filters, + self.no_convs_per_block, + initializers, + posterior=self.posterior, + num_classes=num_classes + ) self.conv_layer = nn.Conv2d(num_filters[-1], 2 * self.latent_dim, (1, 1), stride=1) self.show_img = 0 @@ -254,8 +256,8 @@ def tile(self, a, dim, n_tile): repeat_idx[dim] = n_tile a = a.repeat(*(repeat_idx)) order_index = torch.LongTensor( - np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) - ).to(self.device) + np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) + ).to(self.device) return torch.index_select(a, dim, order_index) def forward(self, feature_map, z): @@ -335,40 +337,40 @@ def __init__( self.device = device self.unet = UNet2d( - in_channels=self.input_channels, - out_channels=None, - depth=len(self.num_filters), - initial_features=num_filters[0] - ).to(self.device) + in_channels=self.input_channels, + out_channels=None, + depth=len(self.num_filters), + initial_features=num_filters[0] + ).to(self.device) self.prior = AxisAlignedConvGaussian( - self.input_channels, - self.num_filters, - self.no_convs_per_block, - self.latent_dim, - self.initializers - ).to(self.device) + self.input_channels, + self.num_filters, + self.no_convs_per_block, + self.latent_dim, + self.initializers + ).to(self.device) self.posterior = AxisAlignedConvGaussian( - self.input_channels, - self.num_filters, - self.no_convs_per_block, - self.latent_dim, - self.initializers, - posterior=True, - num_classes=num_classes - ).to(self.device) + self.input_channels, + self.num_filters, + self.no_convs_per_block, + self.latent_dim, + self.initializers, + posterior=True, + num_classes=num_classes + ).to(self.device) self.fcomb = Fcomb( - self.num_filters, - self.latent_dim, - self.input_channels, - self.num_classes, - self.no_convs_fcomb, - {'w': 'orthogonal', 'b': 'normal'}, - use_tile=True, - device=self.device - ).to(self.device) + self.num_filters, + self.latent_dim, + self.input_channels, + self.num_classes, + self.no_convs_fcomb, + {'w': 'orthogonal', 'b': 'normal'}, + use_tile=True, + device=self.device + ).to(self.device) def _check_shape(self, patch): spatial_shape = tuple(patch.shape)[2:] @@ -449,8 +451,8 @@ def elbo(self, segm, consm=None, analytic_kl=True, reconstruct_posterior_mean=Fa z_posterior = self.posterior_latent_space.rsample() self.kl = torch.mean( - self.kl_divergence(analytic=analytic_kl, calculate_posterior=False, z_posterior=z_posterior) - ) + self.kl_divergence(analytic=analytic_kl, calculate_posterior=False, z_posterior=z_posterior) + ) # Here we use the posterior sample sampled above self.reconstruction = self.reconstruct(use_posterior_mean=reconstruct_posterior_mean, From 899aec30847f5b28da983e3c52a8b7f6df583a6f Mon Sep 17 00:00:00 2001 From: anwai98 Date: Sun, 6 Oct 2024 00:04:40 +0200 Subject: [PATCH 4/4] Add support for multi rater annotators --- torch_em/model/probabilistic_unet.py | 24 +++++++++++-------- torch_em/self_training/loss.py | 23 ++++++++++++++++-- .../probabilistic_unet_trainer.py | 2 ++ 3 files changed, 37 insertions(+), 12 deletions(-) diff --git a/torch_em/model/probabilistic_unet.py b/torch_em/model/probabilistic_unet.py index 280d0452..5bbf6537 100644 --- a/torch_em/model/probabilistic_unet.py +++ b/torch_em/model/probabilistic_unet.py @@ -200,7 +200,6 @@ def __init__( num_filters, latent_dim, num_output_channels, - num_classes, no_convs_fcomb, initializers, use_tile=True, @@ -209,8 +208,7 @@ def __init__( super().__init__() - self.num_channels = num_output_channels - self.num_classes = num_classes + self.num_output_channels = num_output_channels self.channel_axis = 1 self.spatial_axes = [2, 3] self.num_filters = num_filters @@ -237,7 +235,7 @@ def __init__( self.layers = nn.Sequential(*layers) - self.last_layer = nn.Conv2d(self.num_filters[0], self.num_classes, kernel_size=1) + self.last_layer = nn.Conv2d(self.num_filters[0], self.num_output_channels, kernel_size=1) if initializers['w'] == 'orthogonal': self.layers.apply(init_weights_orthogonal_normal) @@ -284,16 +282,18 @@ class ProbabilisticUNet(nn.Module): The following elements are initialized to get our desired network: input_channels: the number of channels in the image (1 for grayscale and 3 for RGB) - num_classes: the number of classes to predict + output_channels: the number of channels to predict. + num_classes: the number of classes (raters) for the posterior num_filters: is a list consisting of the amount of filters layer latent_dim: dimension of the latent space - no_cons_per_block: no convs per block in the (convolutional) encoder of prior and posterior + no_convs_per_block: no convs per block in the (convolutional) encoder of prior and posterior beta: KL and reconstruction loss are weighted using a KL weighting factor (β) consensus_masking: activates consensus masking in the reconstruction loss rl_swap: switches the reconstruction loss to dice loss from the default (binary cross-entroy loss) Parameters: input_channels [int] - (default: 1) + output_channels [int] - (default: 1) num_classes [int] - (default: 1) num_filters [list] - (default: [32, 64, 128, 192]) latent_dim [int] - (default: 6) @@ -307,6 +307,7 @@ class ProbabilisticUNet(nn.Module): def __init__( self, input_channels=1, + output_channels=1, num_classes=1, num_filters=[32, 64, 128, 192], latent_dim=6, @@ -320,6 +321,7 @@ def __init__( super().__init__() self.input_channels = input_channels + self.output_channels = output_channels self.num_classes = num_classes self.num_filters = num_filters self.latent_dim = latent_dim @@ -364,8 +366,7 @@ def __init__( self.fcomb = Fcomb( self.num_filters, self.latent_dim, - self.input_channels, - self.num_classes, + self.output_channels, self.no_convs_fcomb, {'w': 'orthogonal', 'b': 'normal'}, use_tile=True, @@ -455,8 +456,11 @@ def elbo(self, segm, consm=None, analytic_kl=True, reconstruct_posterior_mean=Fa ) # Here we use the posterior sample sampled above - self.reconstruction = self.reconstruct(use_posterior_mean=reconstruct_posterior_mean, - calculate_posterior=False, z_posterior=z_posterior) + self.reconstruction = self.reconstruct( + use_posterior_mean=reconstruct_posterior_mean, + calculate_posterior=False, + z_posterior=z_posterior + ) if self.consensus_masking is True and consm is not None: reconstruction_loss = criterion(self.reconstruction * consm, segm * consm) diff --git a/torch_em/self_training/loss.py b/torch_em/self_training/loss.py index 963e95d8..1e972529 100644 --- a/torch_em/self_training/loss.py +++ b/torch_em/self_training/loss.py @@ -78,13 +78,19 @@ class ProbabilisticUNetLoss(nn.Module): # TODO : Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.) loss [nn.Module] - the loss function to be used. (default: None) """ - def __init__(self, loss=None): + def __init__(self, loss=None, output_channels=None): super().__init__() self.loss = loss + self.output_channels = output_channels def __call__(self, model, input_, labels, label_filter=None): model.forward(input_, labels) + # NOTE: 'output_channels' ensures to compute loss over only one label set (in case of multi-rater annotation). + # In the current experiment, we consider the first label out of the bunch. + if self.output_channels is not None: + labels = labels[:, :self.output_channels, ...] + if self.loss is None: elbo = model.elbo(labels, label_filter) reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \ @@ -105,16 +111,29 @@ class ProbabilisticUNetLossAndMetric(nn.Module): activation [nn.Module, callable] - the activation function to be applied to the prediction before evaluating the average predictions. (default: None) """ - def __init__(self, loss=None, metric=DiceLoss(), activation=torch.nn.Sigmoid(), prior_samples=16): + def __init__( + self, + loss=None, + metric=DiceLoss(), + activation=torch.nn.Sigmoid(), + prior_samples=16, + output_channels=None, + ): super().__init__() self.activation = activation self.metric = metric self.loss = loss self.prior_samples = prior_samples + self.output_channels = output_channels def __call__(self, model, input_, labels, label_filter=None): model.forward(input_, labels) + # NOTE: 'output_channels' ensures to compute loss over only one label set (in case of multi-rater annotation). + # In the current experiment, we consider the first label out of the bunch. + if self.output_channels is not None: + labels = labels[:, :self.output_channels, ...] + if self.loss is None: elbo = model.elbo(labels, label_filter) reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \ diff --git a/torch_em/self_training/probabilistic_unet_trainer.py b/torch_em/self_training/probabilistic_unet_trainer.py index 85f98476..56a50737 100644 --- a/torch_em/self_training/probabilistic_unet_trainer.py +++ b/torch_em/self_training/probabilistic_unet_trainer.py @@ -78,6 +78,7 @@ def _train_epoch_impl(self, progress, forward_context, backprop): lr = [pm["lr"] for pm in self.optimizer.param_groups][0] # We only sample if we log images in this iteration. samples = self._sample() if self._iteration % self.log_image_interval == 0 else None + y = y[:, :self.model.output_channels, ...] self.logger.log_train(self._iteration, loss, lr, x, y, samples) self._iteration += 1 @@ -110,6 +111,7 @@ def _validate_impl(self, forward_context): if self.logger is not None: samples = self._sample() + y = y[:, :self.model.output_channels, ...] self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, samples) return metric_val