diff --git a/.devcontainer/onCreateCommand.sh b/.devcontainer/onCreateCommand.sh index b2fb94354..d2f104a51 100755 --- a/.devcontainer/onCreateCommand.sh +++ b/.devcontainer/onCreateCommand.sh @@ -6,4 +6,4 @@ pip install wheel pip install openvino-dev==2023.0.1 # [OPTIONAL] to generate optimized models for inference pip install mlcube_docker # [OPTIONAL] to deploy GaNDLF models as MLCube-compliant Docker containers pip install medmnist==2.1.0 -pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cpu +pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cpu diff --git a/.devcontainer/postCreateCommand.sh b/.devcontainer/postCreateCommand.sh index 8428eb5d7..163341712 100755 --- a/.devcontainer/postCreateCommand.sh +++ b/.devcontainer/postCreateCommand.sh @@ -6,7 +6,7 @@ # if runnning on a GPU machine, install the GPU version of pytorch if command -v nvidia-smi &> /dev/null then - pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121 + pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu121 fi pip install -e . diff --git a/.github/workflows/dependencies/action.yml b/.github/workflows/dependencies/action.yml index 32f862b54..9aa2d2dd2 100644 --- a/.github/workflows/dependencies/action.yml +++ b/.github/workflows/dependencies/action.yml @@ -100,5 +100,5 @@ runs: python -m pip install --upgrade pip==24.0 python -m pip install wheel python -m pip install openvino-dev==2023.0.1 mlcube_docker - pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cpu + pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cpu pip install -e . diff --git a/Dockerfile-CPU b/Dockerfile-CPU index d9c87de25..be93294c1 100644 --- a/Dockerfile-CPU +++ b/Dockerfile-CPU @@ -9,7 +9,7 @@ RUN add-apt-repository ppa:deadsnakes/ppa RUN apt-get update && apt-get install -y python3.9 python3-pip libjpeg8-dev zlib1g-dev python3-dev libpython3.9-dev libffi-dev libgl1 RUN python3.9 -m pip install --upgrade pip==24.0 # EXPLICITLY install cpu versions of torch/torchvision (not all versions have +cpu modes on PyPI...) -RUN python3.9 -m pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cpu +RUN python3.9 -m pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cpu RUN python3.9 -m pip install openvino-dev==2023.0.1 opencv-python-headless mlcube_docker # Do some dependency installation separately here to make layer caching more efficient @@ -32,7 +32,7 @@ CMD run # See https://github.com/hexops/dockerfile as a best practices guide. #RUN addgroup --gid 10001 --system nonroot \ # && adduser --uid 10000 --system --ingroup nonroot --home /home/nonroot nonroot -# +# #USER nonroot # Prepare the container for possible model embedding later. 
diff --git a/Dockerfile-CUDA11.8 b/Dockerfile-CUDA11.8 index 6b06fcda5..84ecd3ab9 100644 --- a/Dockerfile-CUDA11.8 +++ b/Dockerfile-CUDA11.8 @@ -12,7 +12,7 @@ RUN apt-get update && apt-get install -y software-properties-common RUN add-apt-repository ppa:deadsnakes/ppa RUN apt-get update && apt-get install -y python3.9 python3-pip libjpeg8-dev zlib1g-dev python3-dev libpython3.9-dev libffi-dev libgl1 RUN python3.9 -m pip install --upgrade pip==24.0 -RUN python3.9 -m pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118 +RUN python3.9 -m pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu118 RUN python3.9 -m pip install openvino-dev==2023.0.1 opencv-python-headless mlcube_docker # Do some dependency installation separately here to make layer caching more efficient diff --git a/Dockerfile-CUDA12.1 b/Dockerfile-CUDA12.1 index 4da63a335..1807b5562 100644 --- a/Dockerfile-CUDA12.1 +++ b/Dockerfile-CUDA12.1 @@ -12,7 +12,7 @@ RUN apt-get update && apt-get install -y software-properties-common RUN add-apt-repository ppa:deadsnakes/ppa RUN apt-get update && apt-get install -y python3.9 python3-pip libjpeg8-dev zlib1g-dev python3-dev libpython3.9-dev libffi-dev libgl1 RUN python3.9 -m pip install --upgrade pip==24.0 -RUN python3.9 -m pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121 +RUN python3.9 -m pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu121 RUN python3.9 -m pip install openvino-dev==2023.0.1 opencv-python-headless mlcube_docker # Do some dependency installation separately here to make layer caching more efficient diff --git a/Dockerfile-ROCm b/Dockerfile-ROCm index 5d0fb7450..60cd7f5a3 100644 --- a/Dockerfile-ROCm +++ b/Dockerfile-ROCm @@ -1,4 +1,4 @@ -FROM rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch +FROM rocm/pytorch:rocm6.1_ubuntu20.04_py3.9_pytorch LABEL github="https://github.com/mlcommons/GaNDLF" LABEL docs="https://mlcommons.github.io/GaNDLF/" LABEL version=1.0 @@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y software-properties-common RUN add-apt-repository ppa:deadsnakes/ppa RUN apt-get update && apt-get install -y python3.9 python3-pip libjpeg8-dev zlib1g-dev python3-dev libpython3.9-dev libffi-dev libgl1 RUN python3.9 -m pip install --upgrade pip==24.0 -RUN python3.9 -m pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/rocm6.0 +RUN python3.9 -m pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/rocm6.0 RUN python3.9 -m pip install --upgrade pip && python3.9 -m pip install openvino-dev==2023.0.1 opencv-python-headless mlcube_docker RUN apt-get update && apt-get install -y libgl1 @@ -36,7 +36,7 @@ CMD run # See https://github.com/hexops/dockerfile as a best practices guide. #RUN addgroup --gid 10001 --system nonroot \ # && adduser --uid 10000 --system --ingroup nonroot --home /home/nonroot nonroot -# +# #USER nonroot # Prepare the container for possible model embedding later. 
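The pinned `torch==2.5.0` wheels above come from build-specific index URLs, so it is worth confirming that the installed variant actually matches the intended backend. A minimal sanity-check sketch (not part of this diff, assuming a standard Python environment after one of the installs above):

```python
# The local version suffix should match the index URL used above,
# e.g. "2.5.0+cu121", "2.5.0+cpu", or "2.5.0+rocm6.0".
import torch
import torchvision

print(torch.__version__, torchvision.__version__)
# True on both CUDA and ROCm builds when a compatible GPU is visible
print("GPU available:", torch.cuda.is_available())
```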
diff --git a/GANDLF/compute/inference_loop.py b/GANDLF/compute/inference_loop.py index 2c9c0230a..6910b5974 100644 --- a/GANDLF/compute/inference_loop.py +++ b/GANDLF/compute/inference_loop.py @@ -15,7 +15,7 @@ from skimage.io import imsave from tqdm import tqdm from torch.cuda.amp import autocast -import tiffslide as openslide +import openslide from GANDLF.data import get_testing_loader from GANDLF.utils import ( best_model_path_end, @@ -344,11 +344,13 @@ def inference_loop( ) cv2.imwrite(file_to_write, heatmaps[key]) - os_image_array = os_image.read_region( - (0, 0), - parameters["slide_level"], - (level_width, level_height), - as_array=True, + # this is needed because openslide returns an RGBA image + os_image_array = np.asarray( + os_image.read_region( + (0, 0), + parameters["slide_level"], + (level_width, level_height), + ).convert("RGB") ) blended_image = cv2.addWeighted( os_image_array, diff --git a/GANDLF/data/inference_dataloader_histopath.py b/GANDLF/data/inference_dataloader_histopath.py index f4380c412..f24e88a67 100644 --- a/GANDLF/data/inference_dataloader_histopath.py +++ b/GANDLF/data/inference_dataloader_histopath.py @@ -1,7 +1,7 @@ import os from typing import Optional import numpy as np -import tiffslide +import openslide from GANDLF.data.patch_miner.opm.utils import get_patch_size_in_microns, tissue_mask from skimage.transform import resize from torch.utils.data.dataset import Dataset @@ -51,7 +51,7 @@ def __init__( self._stride_size = get_patch_size_in_microns(wsi_path, self._stride_size) self._selected_level = selected_level self._mask_level = mask_level - self._os_image = tiffslide.open_slide(os.path.join(self._wsi_path)) + self._os_image = openslide.open_slide(os.path.join(self._wsi_path)) self._points = [] self._basic_preprocessing() @@ -61,11 +61,13 @@ def _basic_preprocessing(self): try: mask_xdim, mask_ydim = self._os_image.level_dimensions[self._mask_level] mask = get_tissue_mask( - self._os_image.read_region( - (0, 0), self._mask_level, (mask_xdim, mask_ydim), as_array=True + # this is needed because openslide returns an RGBA image + np.asarray( + self._os_image.read_region( + (0, 0), self._mask_level, (mask_xdim, mask_ydim) + ).convert("RGB") ) ) - if self._selected_level != self._mask_level: mask = resize(mask, (height, width)) mask = (mask > 0).astype(np.ubyte) @@ -134,9 +136,10 @@ def __getitem__(self, idx): (x_loc, y_loc), self._selected_level, (self._patch_size[0], self._patch_size[1]), - as_array=True, - ) + # as_array=True, openslide-python returns an RGBA PIL image + ).convert("RGB") + patch = np.asarray(patch) # convert the image to ndarray # this is to ensure that channels come at the beginning patch = patch.transpose([2, 0, 1]) # this is to ensure that we always have a z-stack before applying any torchio transforms diff --git a/GANDLF/data/patch_miner/opm/patch.py b/GANDLF/data/patch_miner/opm/patch.py index dc2edb16f..43ba5a09b 100644 --- a/GANDLF/data/patch_miner/opm/patch.py +++ b/GANDLF/data/patch_miner/opm/patch.py @@ -46,7 +46,9 @@ def read_patch(self): return np.asarray( self.slide_object.read_region( (self.coordinates[1], self.coordinates[0]), self.level, self.size - ) + ).convert( + "RGB" + ) # openslide-python returns an RGBA PIL image ) def copy(self): diff --git a/GANDLF/data/patch_miner/opm/patch_manager.py b/GANDLF/data/patch_miner/opm/patch_manager.py index d32e769b5..280f6af47 100644 --- a/GANDLF/data/patch_miner/opm/patch_manager.py +++ b/GANDLF/data/patch_miner/opm/patch_manager.py @@ -7,7 +7,7 @@ from tqdm import tqdm from 
pathlib import Path import pandas as pd -import tiffslide +import openslide class PatchManager: @@ -41,7 +41,7 @@ def set_subjectID(self, subjectID): def set_slide_path(self, filename): self.img_path = filename self.img_path = convert_to_tiff(self.img_path, self.output_dir, "img") - self.slide_object = tiffslide.open_slide(self.img_path) + self.slide_object = openslide.open_slide(self.img_path) self.slide_dims = self.slide_object.dimensions def set_label_map(self, path): @@ -50,7 +50,7 @@ def set_label_map(self, path): @param path: path to label map. """ self.label_map = convert_to_tiff(path, self.output_dir, "mask") - self.label_map_object = tiffslide.open_slide(self.label_map) + self.label_map_object = openslide.open_slide(self.label_map) assert all( x == y for x, y in zip(self.label_map_object.dimensions, self.slide_dims) diff --git a/GANDLF/data/patch_miner/opm/utils.py b/GANDLF/data/patch_miner/opm/utils.py index ed2e57258..a0ebaf56f 100644 --- a/GANDLF/data/patch_miner/opm/utils.py +++ b/GANDLF/data/patch_miner/opm/utils.py @@ -17,7 +17,7 @@ # import matplotlib.pyplot as plt import yaml -import tiffslide +import openslide # RGB Masking (pen) constants RGB_RED_CHANNEL = 0 @@ -428,7 +428,7 @@ def generate_initial_mask(slide_path: str, scale: int) -> Tuple[np.ndarray, tupl Tuple[np.ndarray, tuple]: The valid mask and the real scale. """ # Open slide and get properties - slide = tiffslide.open_slide(slide_path) + slide = openslide.open_slide(slide_path) slide_dims = slide.dimensions # Call thumbnail for efficiency, calculate scale relative to whole slide @@ -505,26 +505,26 @@ def get_patch_size_in_microns( "Using mpp to calculate patch size for dimension {}".format(i) ) # only enter if "m" is present in patch size - input_slide = tiffslide.open_slide(input_slide_path) + input_slide = openslide.open_slide(input_slide_path) metadata = input_slide.properties if i == 0: for _property in [ - tiffslide.PROPERTY_NAME_MPP_X, + openslide.PROPERTY_NAME_MPP_X, "tiff.XResolution", "XResolution", ]: if _property in metadata: - magnification = metadata[_property] + magnification = float(metadata[_property]) magnification_prev = magnification break elif i == 1: for _property in [ - tiffslide.PROPERTY_NAME_MPP_Y, + openslide.PROPERTY_NAME_MPP_Y, "tiff.YResolution", "YResolution", ]: if _property in metadata: - magnification = metadata[_property] + magnification = float(metadata[_property]) break if magnification == -1: # if y-axis data is missing, use x-axis data diff --git a/GANDLF/entrypoints/run.py b/GANDLF/entrypoints/run.py index 842c47b5d..f202f2b9f 100644 --- a/GANDLF/entrypoints/run.py +++ b/GANDLF/entrypoints/run.py @@ -5,9 +5,7 @@ import argparse import ast -# import traceback from typing import Optional - from deprecated import deprecated import click diff --git a/GANDLF/losses/hybrid.py b/GANDLF/losses/hybrid.py index ddf62fa01..f4c862606 100644 --- a/GANDLF/losses/hybrid.py +++ b/GANDLF/losses/hybrid.py @@ -1,5 +1,4 @@ import torch - from .segmentation import MCD_loss, FocalLoss from .regression import CCE_Generic, CE, CE_Logits diff --git a/GANDLF/losses/hybrid_new.py b/GANDLF/losses/hybrid_new.py new file mode 100644 index 000000000..4fa7edfcc --- /dev/null +++ b/GANDLF/losses/hybrid_new.py @@ -0,0 +1,21 @@ +from .regression_new import BinaryCrossEntropyLoss, BinaryCrossEntropyWithLogitsLoss +from .segmentation_new import MulticlassDiceLoss, MulticlassFocalLoss +from .loss_interface import AbstractHybridLoss + + +class DiceCrossEntropyLoss(AbstractHybridLoss): + def 
_initialize_all_loss_calculators(self): + return [MulticlassDiceLoss(self.params), BinaryCrossEntropyLoss(self.params)] + + +class DiceCrossEntropyLossLogits(AbstractHybridLoss): + def _initialize_all_loss_calculators(self): + return [ + MulticlassDiceLoss(self.params), + BinaryCrossEntropyWithLogitsLoss(self.params), + ] + + +class DiceFocalLoss(AbstractHybridLoss): + def _initialize_all_loss_calculators(self): + return [MulticlassDiceLoss(self.params), MulticlassFocalLoss(self.params)] diff --git a/GANDLF/losses/loss_interface.py b/GANDLF/losses/loss_interface.py new file mode 100644 index 000000000..e8459f41d --- /dev/null +++ b/GANDLF/losses/loss_interface.py @@ -0,0 +1,153 @@ +import torch +from torch import nn +from abc import ABC, abstractmethod +from typing import List + + +class AbstractLossFunction(nn.Module, ABC): + def __init__(self, params: dict): + nn.Module.__init__(self) + self.params = params + self.num_classes = len(params["model"]["class_list"]) + self._initialize_penalty_weights() + + def _initialize_penalty_weights(self): + default_penalty_weights = torch.ones(self.num_classes) + self.penalty_weights = self.params.get( + "penalty_weights", default_penalty_weights + ) + + @abstractmethod + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the loss function. To be implemented by child classes. + """ + + +class AbstractSegmentationLoss(AbstractLossFunction): + """ + Base class for loss functions that are used for segmentation tasks. + """ + + def __init__(self, params: dict): + super().__init__(params) + + def _compute_single_class_loss( + self, prediction: torch.Tensor, target: torch.Tensor, class_idx: int + ) -> torch.Tensor: + """Compute loss for a single class.""" + loss_value = self._single_class_loss_calculator( + prediction[:, class_idx, ...], target[:, class_idx, ...] + ) + return 1 - loss_value + + def _optional_loss_operations(self, loss: torch.Tensor) -> torch.Tensor: + """ + Perform additional operations on the loss value. Defaults to the identity operation. + If needed, child classes can override this method. Useful in cases where, + for example, the loss value needs to be log-transformed or clipped. + """ + return loss + + @abstractmethod + def _single_class_loss_calculator( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Compute loss for a pair of prediction and target tensors. To be implemented by child classes. + """ + + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + accumulated_loss = torch.tensor(0.0, device=prediction.device) + + for class_idx in range(self.num_classes): + current_loss = self._compute_single_class_loss( + prediction, target, class_idx + ) + accumulated_loss += ( + self._optional_loss_operations(current_loss) + * self.penalty_weights[class_idx] + ) + + accumulated_loss /= self.num_classes + + return accumulated_loss + + +class AbstractRegressionLoss(AbstractLossFunction): + """ + Base class for loss functions that are used for regression and classification tasks. + """ + + def __init__(self, params: dict): + super().__init__(params) + self.loss_calculator = self._initialize_loss_function_object() + self.reduction_method = self._initialize_reduction_method() + + def _initialize_reduction_method(self) -> str: + """ + Initialize the reduction method for the loss function. Defaults to 'mean'.
+ """ + loss_params = self.params["loss_function"] + reduction_method = "mean" + if isinstance(loss_params, dict): + reduction_method = loss_params.get("reduction", reduction_method) + assert reduction_method in [ + "mean", + "sum", + ], f"Invalid reduction method defined for loss function: {reduction_method}. Valid options are ['mean', 'sum']" + return reduction_method + + def _calculate_loss_for_single_class( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Calculate loss for a single class. Can be overridden by child classes if needed. + """ + return self.loss_calculator(prediction, target) + + @abstractmethod + def _initialize_loss_function_object(self) -> nn.modules.loss._Loss: + """ + Initialize the loss function object used in the forward method. Has to return a + callable PyTorch loss function object. + """ + + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + accumulated_loss = torch.tensor(0.0, device=prediction.device) + for class_idx in range(self.num_classes): + accumulated_loss += ( + self._calculate_loss_for_single_class( + prediction[:, class_idx, ...], target[:, class_idx, ...] + ) + * self.penalty_weights[class_idx] + ) + + accumulated_loss /= self.num_classes + + return accumulated_loss + + +class AbstractHybridLoss(AbstractLossFunction): + """ + Base class for hybrid loss functions that are used for segmentation tasks. + """ + + def __init__(self, params: dict): + super().__init__(params) + self.loss_calculators = self._initialize_all_loss_calculators() + + @abstractmethod + def _initialize_all_loss_calculators(self) -> List[AbstractLossFunction]: + """ + Each hybrid loss should implement this method, creating all loss functions as a list that + will be used during the forward pass. + """ + pass + + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + accumulated_loss = torch.tensor(0.0, device=prediction.device) + for loss_calculator in self.loss_calculators: + accumulated_loss += loss_calculator(prediction, target) + + return accumulated_loss diff --git a/GANDLF/losses/regression.py b/GANDLF/losses/regression.py index 6d74a33a2..4949bd9d2 100644 --- a/GANDLF/losses/regression.py +++ b/GANDLF/losses/regression.py @@ -1,8 +1,8 @@ from typing import Optional import torch import torch.nn.functional as F -from torch.nn import CrossEntropyLoss from GANDLF.utils import one_hot +from torch.nn import CrossEntropyLoss def CEL( diff --git a/GANDLF/losses/regression_new.py b/GANDLF/losses/regression_new.py new file mode 100644 index 000000000..e9e0d5db0 --- /dev/null +++ b/GANDLF/losses/regression_new.py @@ -0,0 +1,64 @@ +import torch +from torch import nn +from .loss_interface import AbstractRegressionLoss + + +class CrossEntropyLoss(AbstractRegressionLoss): + """ + This class computes the cross entropy loss between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.CrossEntropyLoss(reduction=self.reduction_method) + + +class BinaryCrossEntropyLoss(AbstractRegressionLoss): + """ + This class computes the binary cross entropy loss between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.BCELoss(reduction=self.reduction_method) + + +class BinaryCrossEntropyWithLogitsLoss(AbstractRegressionLoss): + """ + This class computes the binary cross entropy loss with logits between two tensors.
+ """ + + def _initialize_loss_function_object(self): + return nn.BCEWithLogitsLoss(reduction=self.reduction_method) + + +class BaseLossWithScaledTarget(AbstractRegressionLoss): + """ + General interface for the loss functions requiring scaling of the target tensor. + """ + + def _initialize_scaling_factor(self): + loss_params = self.params["loss_function"] + self.scaling_factor = 1.0 + if isinstance(loss_params, dict): + self.scaling_factor = loss_params.get("scaling_factor", self.scaling_factor) + return self.scaling_factor + + def _calculate_loss(self, prediction: torch.Tensor, target: torch.Tensor): + return self.loss_calculator(prediction, target * self.scaling_factor) + + +class L1Loss(BaseLossWithScaledTarget): + """ + This class computes the L1 loss between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.L1Loss(reduction=self.reduction_method) + + +class MSELoss(BaseLossWithScaledTarget): + """ + This class computes the mean squared error loss between two tensors. + """ + + def _initialize_loss_function_object(self): + return nn.MSELoss(reduction=self.reduction_method) diff --git a/GANDLF/losses/segmentation_new.py b/GANDLF/losses/segmentation_new.py new file mode 100644 index 000000000..4999686fe --- /dev/null +++ b/GANDLF/losses/segmentation_new.py @@ -0,0 +1,193 @@ +import sys +import torch +from .loss_interface import AbstractSegmentationLoss, AbstractLossFunction + + +class MulticlassDiceLoss(AbstractSegmentationLoss): + """ + This class computes the Dice loss between two tensors. + """ + + def _single_class_loss_calculator( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Compute Dice score for a single class. + + Args: + prediction (torch.Tensor): Network's predicted segmentation mask + target (torch.Tensor): Target segmentation mask + + Returns: + torch.Tensor: The computed dice score. + """ + predicted_flat = prediction.flatten() + label_flat = target.flatten() + intersection = (predicted_flat * label_flat).sum() + + dice_score = (2.0 * intersection + sys.float_info.min) / ( + predicted_flat.sum() + label_flat.sum() + sys.float_info.min + ) + + return dice_score + + +class MulticlassDiceLogLoss(MulticlassDiceLoss): + def _optional_loss_operations(self, loss): + return -torch.log( + loss + torch.finfo(torch.float32).eps + ) # epsilon for numerical stability + + +class MulticlassMCCLoss(AbstractSegmentationLoss): + """ + This class computes the Matthews Correlation Coefficient (MCC) loss between two tensors. + """ + + def _single_class_loss_calculator( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Compute MCC score for a single class. + + Args: + prediction (torch.Tensor): Network's predicted segmentation mask + target (torch.Tensor): Target segmentation mask + + Returns: + torch.Tensor: The computed MCC score. + """ + tp = torch.sum(torch.mul(prediction, target)) + tn = torch.sum(torch.mul((1 - prediction), (1 - target))) + fp = torch.sum(torch.mul(prediction, (1 - target))) + fn = torch.sum(torch.mul((1 - prediction), target)) + + numerator = torch.mul(tp, tn) - torch.mul(fp, fn) + # Adding epsilon to the denominator to avoid divide-by-zero errors.
+ denominator = ( + torch.sqrt( + torch.add(tp, 1, fp) + * torch.add(tp, 1, fn) + * torch.add(tn, 1, fp) + * torch.add(tn, 1, fn) + ) + + torch.finfo(torch.float32).eps + ) + + return torch.div(numerator.sum(), denominator.sum()) + + +class MulticlassMCLLogLoss(MulticlassMCCLoss): + def _optional_loss_operations(self, loss): + return -torch.log( + loss + torch.finfo(torch.float32).eps + ) # epsilon for numerical stability + + +class MulticlassTverskyLoss(AbstractSegmentationLoss): + """ + This class computes the Tversky loss between two tensors. + """ + + def __init__(self, params: dict): + super().__init__(params) + loss_params = params["loss_function"] + self.alpha = 0.5 + self.beta = 0.5 + if isinstance(loss_params, dict): + self.alpha = loss_params.get("alpha", self.alpha) + self.beta = loss_params.get("beta", self.beta) + + def _single_class_loss_calculator( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Compute Tversky score for a single class. + + Args: + prediction (torch.Tensor): Network's predicted segmentation mask + target (torch.Tensor): Target segmentation mask + + Returns: + torch.Tensor: The computed Tversky score. + """ + predicted_flat = prediction.contiguous().view(-1) + target_flat = target.contiguous().view(-1) + + true_positives = (predicted_flat * target_flat).sum() + false_positives = ((1 - target_flat) * predicted_flat).sum() + false_negatives = (target_flat * (1 - predicted_flat)).sum() + + numerator = true_positives + denominator = ( + true_positives + self.alpha * false_positives + self.beta * false_negatives + ) + loss = (numerator + sys.float_info.min) / (denominator + sys.float_info.min) + + return loss + + +class MulticlassFocalLoss(AbstractSegmentationLoss): + """ + This class computes the Focal loss between two tensors. + """ + + def __init__(self, params: dict): + super().__init__(params) + + self.ce_loss_helper = torch.nn.CrossEntropyLoss(reduction="none") + loss_params = params["loss_function"] + self.alpha = 1.0 + self.gamma = 2.0 + self.output_aggregation = "sum" + if isinstance(loss_params, dict): + self.alpha = loss_params.get("alpha", self.alpha) + self.gamma = loss_params.get("gamma", self.gamma) + self.output_aggregation = loss_params.get( + "size_average", + self.output_aggregation, # naming mismatch of key due to keeping API consistent with config format + ) + assert self.output_aggregation in [ + "sum", + "mean", + ], f"Invalid output aggregation method defined for Focal Loss: {self.output_aggregation}. Valid options are ['sum', 'mean']" + + def _single_class_loss_calculator( + self, prediction: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """ + Compute focal loss for a single class. It is based on the following formulas: + FocalLoss(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t) + CrossEntropy(pred, target) = -log(pred) if target = 1 else -log(1 - pred) + CrossEntropy(p_t) = CrossEntropy(pred, target) = -log(p_t) + p_t = p if target = 1 else 1 - p + """ + ce_loss = self.ce_loss_helper(prediction, target) + p_t = torch.exp(-ce_loss) + loss = self.alpha * (1 - p_t) ** self.gamma * ce_loss # positive focal loss, per the formula above + return loss.sum() if self.output_aggregation == "sum" else loss.mean() + + def _compute_single_class_loss( + self, prediction: torch.Tensor, target: torch.Tensor, class_idx: int + ) -> torch.Tensor: + """Compute loss for a single class.""" + loss_value = self._single_class_loss_calculator( + prediction[:, class_idx, ...], target[:, class_idx, ...]
+ ) + return loss_value # no need to subtract from 1 in this case, hence the override + + +class KullbackLeiblerDivergence(AbstractLossFunction): + def forward(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: + """ + Calculates the Kullback-Leibler divergence between two Gaussian distributions. + + Args: + mu (torch.Tensor): The mean of the first Gaussian distribution. + logvar (torch.Tensor): The logarithm of the variance of the first Gaussian distribution. + + Returns: + torch.Tensor: The computed Kullback-Leibler divergence. + """ + loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1) + return loss.mean() diff --git a/GANDLF/optimizers/README.md b/GANDLF/optimizers/README.md index b12a61f12..8d1499fb3 100644 --- a/GANDLF/optimizers/README.md +++ b/GANDLF/optimizers/README.md @@ -3,10 +3,16 @@ ## Adding a new algorithm - For an optimizer defined in PyTorch [[ref](https://pytorch.org/docs/stable/optim.html#algorithms)], update the `GANDLF.optimizers.wrap_torch.py` submodule. -- For a custom optimizer, create a new submodule called `GANDLF.optimizers.${awesome_optimizer}.py`. Ensure that it inherits from PyTorch's base optimizer class [[ref](https://pytorch.org/docs/stable/optim.html#base-class)] +- For a custom optimizer, create a new submodule called `GANDLF.optimizers.${awesome_optimizer}.py`. +- For a third-party optimizer (i.e., where the code is available from an external source/repository): + - Add the relevant code under the `GANDLF.optimizers.thirdparty` submodule. + - Add a wrapper which takes in GaNDLF's `parameter` dictionary as input and creates a `torch.optim.Optimizer` object as output. + - Add the wrapper to the `GANDLF.optimizers.thirdparty.__init__.py` so that it can be called from `GANDLF.optimizers.__init__.py`. + - See `GANDLF.optimizers.thirdparty.adopt.py` as an example. - If a new dependency needs to be used, update GaNDLF's [`setup.py`](https://github.com/mlcommons/GaNDLF/blob/master/setup.py) with the new requirement. - Define a new submodule under `GANDLF.optimizers` as `GANDLF.optimizers.wrap_${package_name}.py`. - Ensure that the new algorithm is wrapped in a function which returns an object with the PyTorch optimizer type. Use any of the optimizers in `GANDLF.optimizers.wrap_torch.py` as an example. - Add the algorithm's identifier to `GANDLF.optimizers.__init__.global_optimizer_dict` with an appropriate key. - Call the new algorithm from the config using the `optimizer` key. -- [Update the tests!](https://mlcommons.github.io/GaNDLF/extending/#update-tests)https://mlcommons.github.io/GaNDLF/extending/#update-tests +- [If appropriate, please update the tests!](https://mlcommons.github.io/GaNDLF/extending/#update-tests) +- All wrappers should return an instance of `torch.optim.Optimizer`.
\ No newline at end of file diff --git a/GANDLF/optimizers/__init__.py b/GANDLF/optimizers/__init__.py index b59afb22f..4df3d0ec6 100644 --- a/GANDLF/optimizers/__init__.py +++ b/GANDLF/optimizers/__init__.py @@ -15,7 +15,7 @@ from .wrap_monai import novograd_wrapper -from .ademamix import ademamix_wrapper +from .thirdparty import ademamix_wrapper, lion_wrapper, adopt_wrapper global_optimizer_dict = { "sgd": sgd, @@ -32,6 +32,8 @@ "novograd": novograd_wrapper, "nadam": nadam, "ademamix": ademamix_wrapper, + "lion": lion_wrapper, + "adopt": adopt_wrapper, } @@ -49,9 +51,10 @@ def get_optimizer(params): # Retrieve the optimizer type from the input parameters optimizer_type = params["optimizer"]["type"] + assert ( + optimizer_type in global_optimizer_dict + ), f"Optimizer type {optimizer_type} not found" + # Create the optimizer instance using the specified type and input parameters - if optimizer_type in global_optimizer_dict: - optimizer_function = global_optimizer_dict[optimizer_type] - return optimizer_function(params) - else: - raise ValueError("Optimizer type %s not found" % optimizer_type) + optimizer_function = global_optimizer_dict[optimizer_type] + return optimizer_function(params) diff --git a/GANDLF/optimizers/thirdparty/__init__.py b/GANDLF/optimizers/thirdparty/__init__.py new file mode 100644 index 000000000..7b47ed60c --- /dev/null +++ b/GANDLF/optimizers/thirdparty/__init__.py @@ -0,0 +1,5 @@ +from .ademamix import ademamix_wrapper + +from .lion import lion_wrapper + +from .adopt import adopt_wrapper diff --git a/GANDLF/optimizers/ademamix.py b/GANDLF/optimizers/thirdparty/ademamix.py similarity index 100% rename from GANDLF/optimizers/ademamix.py rename to GANDLF/optimizers/thirdparty/ademamix.py diff --git a/GANDLF/optimizers/thirdparty/adopt.py b/GANDLF/optimizers/thirdparty/adopt.py new file mode 100644 index 000000000..115c61c00 --- /dev/null +++ b/GANDLF/optimizers/thirdparty/adopt.py @@ -0,0 +1,524 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from typing import cast, List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from torch.optim.optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _device_dtype_check_for_fused, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _fused_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _stack_if_compiling, + _use_grad_for_differentiable, + _view_as_real, + DeviceDict, + Optimizer, + ParamsT, +) + + +__all__ = ["ADOPT", "adopt"] + + +class ADOPT(Optimizer): + ### "adapted" from https://github.com/iShohei220/adopt/blob/main/adopt.py + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-3, + betas: Tuple[float, float] = (0.9, 0.9999), + eps: float = 1e-6, + weight_decay: float = 0.0, + decoupled: bool = False, + *, + foreach: Optional[bool] = None, + maximize: bool = False, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + ): + if isinstance(lr, Tensor): + if foreach and not capturable: + raise ValueError( + "lr as a Tensor is not supported for `capturable=False` and `foreach=True`" + ) + if lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise 
ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + decoupled=decoupled, + maximize=maximize, + foreach=foreach, + capturable=capturable, + differentiable=differentiable, + fused=fused, + ) + super().__init__(params, defaults) + + if fused: + # TODO: support fused + raise RuntimeError("`fused` is not currently supported") + + if differentiable: + raise RuntimeError("`fused` does not support `differentiable`") + self._step_supports_amp_scaling = True + # TODO(crcrpar): [low prec params & their higher prec copy] + # Support AMP with FP16/BF16 model params which would need + # higher prec copy of params to do update math in higher prec to + # alleviate the loss of information. + if foreach: + raise RuntimeError("`fused` and `foreach` cannot be `True` together.") + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("maximize", False) + group.setdefault("foreach", None) + group.setdefault("capturable", False) + group.setdefault("differentiable", False) + fused = group.setdefault("fused", None) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, + dtype=_get_scalar_dtype(is_fused=fused), + device=p.device, + ) + if group["capturable"] or group["fused"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps + ): + has_complex = False + for p in group["params"]: + if p.grad is not None: + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("ADOPT does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + if group["fused"]: + _device_dtype_check_for_fused(p) + # note(crcrpar): [special device hosting for step] + # Deliberately host `step` on CPU if both capturable and fused are off. + # This is because kernel launches are costly on CUDA and XLA. + state["step"] = ( + torch.zeros( + (), + dtype=_get_scalar_dtype(is_fused=group["fused"]), + device=p.device, + ) + if group["capturable"] or group["fused"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if group["differentiable"] and state["step"].requires_grad: + raise RuntimeError( + "`requires_grad` is not supported for `step` in differentiable mode" + ) + + # Foreach without capturable does not support a tensor lr + if ( + group["foreach"] + and torch.is_tensor(group["lr"]) + and not group["capturable"] + ): + raise RuntimeError( + "lr as a Tensor is not supported for capturable=False and foreach=True" + ) + + state_steps.append(state["step"]) + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. 
+ + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + state_steps: List[Tensor] = [] + beta1, beta2 = group["betas"] + + has_complex = self._init_group( + group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps + ) + + adopt( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + has_complex=has_complex, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + decoupled=group["decoupled"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +def _single_tensor_adopt( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + has_complex: bool, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + decoupled: bool, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + if torch.jit.is_scripting(): + # this assert is due to JIT being dumb and not realizing that the ops below + # have overloads to handle both float and Tensor lrs, so we just assert it's + # a float since most people using JIT are using floats + assert isinstance(lr, float) + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == step_t.device.type + and param.device.type in capturable_supported_devices + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." 
+ + # update step + step_t += 1 + + if weight_decay != 0: + if decoupled: + param.add_(param, alpha=-lr * weight_decay) + else: + grad = grad.add(param, alpha=weight_decay) + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + if exp_avg is not None: + exp_avg = torch.view_as_real(exp_avg) + if exp_avg_sq is not None: + exp_avg_sq = torch.view_as_real(exp_avg_sq) + param = torch.view_as_real(param) + + step = step_t if capturable or differentiable else _get_value(step_t) + if step == 1: + exp_avg_sq.addcmul_(grad, grad.conj()) + continue + + denom = torch.clamp(exp_avg_sq.sqrt(), eps) + if step == 2: + exp_avg.addcdiv_(grad, denom) + else: + exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1) + + param.add_(exp_avg, alpha=-lr) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + + +def _multi_tensor_adopt( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + has_complex: bool, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + decoupled: bool, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + if len(params) == 0: + return + + if isinstance(lr, Tensor) and not capturable: + raise RuntimeError( + "lr as a Tensor is not supported for capturable=False and foreach=True" + ) + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + assert grad_scale is None and found_inf is None + + assert not differentiable, "_foreach ops don't support autograd" + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, state_steps] # type: ignore[list-item] + ) + for ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_state_steps_, + ), _ in grouped_tensors.values(): + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_exp_avgs = cast(List[Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) + device_state_steps = cast(List[Tensor], device_state_steps_) + + # Handle complex parameters + if has_complex: + _view_as_real( + device_params, device_grads, device_exp_avgs, device_exp_avg_sqs + ) + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. 
+ if not torch._utils.is_compiling() and device_state_steps[0].is_cpu: + torch._foreach_add_( + device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(device_state_steps, 1) + + if weight_decay != 0: + if decoupled: + torch._foreach_add_( + device_params, device_params, alpha=-lr * weight_decay + ) + else: + # Re-use the intermediate memory (device_grads) already allocated for maximize + if maximize: + torch._foreach_add_(device_grads, device_params, alpha=weight_decay) + else: + device_grads = torch._foreach_add( # type: ignore[assignment] + device_grads, device_params, alpha=weight_decay + ) + + if device_state_steps[0] == 1: + torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads) + continue + + exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) + exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps) + + if device_state_steps[0] == 2: + torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt) + else: + torch._foreach_mul_(device_exp_avgs, beta1) + torch._foreach_addcdiv_( + device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1 + ) + + torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr) + torch._foreach_mul_(device_exp_avg_sqs, beta2) + torch._foreach_addcmul_( + device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2 + ) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt) +def adopt( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + has_complex: bool = False, + *, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + decoupled: bool, + eps: float, + maximize: bool, +): + r"""Functional API that performs ADOPT algorithm computation.""" + # Respect when the user inputs False/True for foreach or fused. We only want to change + # the default when neither have been user-specified. Note that we default to foreach + # and pass False to use_fused. This is not a mistake--we want to give the fused impl + # bake-in time before making it the default, even if it is typically faster. + if fused is None and foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False. 
+ if foreach and isinstance(lr, Tensor) and not capturable: + foreach = False + if fused is None: + fused = False + if foreach is None: + foreach = False + + # this check is slow during compilation, so we skip it + # if it's strictly needed we can add this check back in dynamo + if not torch._utils.is_compiling() and not all( + isinstance(t, torch.Tensor) for t in state_steps + ): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + if fused and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with fused optimizers") + + if fused and not torch.jit.is_scripting(): + func = _fused_adopt + elif foreach and not torch.jit.is_scripting(): + func = _multi_tensor_adopt + else: + func = _single_tensor_adopt + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + has_complex=has_complex, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + decoupled=decoupled, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +def adopt_wrapper(parameters: dict) -> torch.optim.Optimizer: + """ + Creates an ADOPT optimizer using the input parameters. + + Args: + parameters (dict): A dictionary containing the input parameters for the optimizer. + + Returns: + torch.optim.Optimizer: An ADOPT optimizer. + """ + + return ADOPT( + params=parameters["model_parameters"], + lr=parameters.get("learning_rate", 1e-3), + betas=parameters.get("betas", (0.9, 0.9999)), + eps=parameters.get("eps", 1e-6), + weight_decay=parameters.get("weight_decay", 0.0), + decoupled=parameters["optimizer"].get("decoupled", False), + foreach=parameters.get("foreach", None), + maximize=parameters.get("maximize", False), + capturable=parameters.get("capturable", False), + differentiable=parameters.get("differentiable", False), + fused=parameters.get("fused", None), + ) diff --git a/GANDLF/optimizers/thirdparty/lion.py b/GANDLF/optimizers/thirdparty/lion.py new file mode 100644 index 000000000..0a6116e21 --- /dev/null +++ b/GANDLF/optimizers/thirdparty/lion.py @@ -0,0 +1,24 @@ +from torch.optim.optimizer import Optimizer +from lion_pytorch import Lion + + +def lion_wrapper(parameters: dict) -> Optimizer: + """ + Creates an instance of the Lion optimizer from the `lion_pytorch` package using the input parameters. + + Args: + parameters (dict): A dictionary containing the input parameters for the optimizer. + + Returns: + Optimizer: An instance of the Lion optimizer.
+ """ + return Lion( + parameters["model_parameters"], + lr=parameters.get("learning_rate", 1e-4), + betas=parameters["optimizer"].get("betas", (0.9, 0.999)), + weight_decay=parameters["optimizer"].get("weight_decay", 0.0), + decoupled_weight_decay=parameters["optimizer"].get( + "decoupled_weight_decay", False + ), + use_triton=False, # as of 20241120, triton is not generally available for all platforms + ) diff --git a/GANDLF/optimizers/wrap_monai.py b/GANDLF/optimizers/wrap_monai.py index 221ba57bd..cdc53515e 100644 --- a/GANDLF/optimizers/wrap_monai.py +++ b/GANDLF/optimizers/wrap_monai.py @@ -1,11 +1,21 @@ -import monai +from torch.optim import Optimizer + from monai.optimizers import Novograd -def novograd_wrapper(parameters: dict) -> monai.optimizers.Novograd: +def novograd_wrapper(parameters) -> Optimizer: + """ + Creates an instance of the Novograd optimizer from the `monai` package using the input parameters. + + Args: + parameters (dict): A dictionary containing the input parameters for the optimizer. + + Returns: + Optimizer: An instance of the Novograd optimizer. + """ return Novograd( parameters["model_parameters"], - lr=parameters.get("learning_rate", 1e-3), + lr=parameters.get("learning_rate"), betas=parameters["optimizer"].get("betas", (0.9, 0.999)), eps=parameters["optimizer"].get("eps", 1e-8), weight_decay=parameters["optimizer"].get("weight_decay", 3e-05), diff --git a/GANDLF/optimizers/wrap_torch.py b/GANDLF/optimizers/wrap_torch.py index 2f4650bdb..d6deff477 100644 --- a/GANDLF/optimizers/wrap_torch.py +++ b/GANDLF/optimizers/wrap_torch.py @@ -1,5 +1,5 @@ -import torch from torch.optim import ( + Optimizer, SGD, ASGD, Rprop, @@ -15,7 +15,7 @@ ) -def sgd(parameters: dict) -> torch.optim.SGD: +def sgd(parameters) -> Optimizer: """ Creates a Stochastic Gradient Descent optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -37,7 +37,7 @@ def sgd(parameters: dict) -> torch.optim.SGD: ) -def asgd(parameters: dict) -> torch.optim.ASGD: +def asgd(parameters) -> Optimizer: """ Creates an Averaged Stochastic Gradient Descent optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -59,7 +59,7 @@ def asgd(parameters: dict) -> torch.optim.ASGD: ) -def adam(parameters: dict, opt_type: str = "normal") -> torch.optim.Adam: +def adam(parameters, opt_type="normal") -> Optimizer: """ Creates an Adam or AdamW optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -72,12 +72,11 @@ def adam(parameters: dict, opt_type: str = "normal") -> torch.optim.Adam: """ # Determine which optimizer to create based on opt_type + assert opt_type in ["normal", "AdamW"], f"Invalid optimizer type: {opt_type}" + optimizer_fn = AdamW + if opt_type == "normal": optimizer_fn = Adam - elif opt_type == "AdamW": - optimizer_fn = AdamW - else: - raise ValueError(f"Invalid optimizer type: {opt_type}") # Create the optimizer using the input parameters return optimizer_fn( @@ -90,7 +89,7 @@ def adam(parameters: dict, opt_type: str = "normal") -> torch.optim.Adam: ) -def adamw(parameters: dict) -> torch.optim.AdamW: +def adamw(parameters) -> Optimizer: """ Creates an AdamW optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -104,7 +103,7 @@ def adamw(parameters: dict) -> torch.optim.AdamW: return adam(parameters, opt_type="AdamW") -def adamax(parameters: dict) -> torch.optim.Adamax: +def adamax(parameters) -> Optimizer: """ Creates an Adamax optimizer from the PyTorch `torch.optim` module using the input parameters. 
@@ -140,7 +139,7 @@ def adamax(parameters: dict) -> torch.optim.Adamax: # ) -def rprop(parameters: dict) -> torch.optim.Rprop: +def rprop(parameters) -> Optimizer: """ Creates a Resilient Backpropagation optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -160,7 +159,7 @@ def rprop(parameters: dict) -> torch.optim.Rprop: ) -def adadelta(parameters: dict) -> torch.optim.Adadelta: +def adadelta(parameters) -> Optimizer: """ Creates an Adadelta optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -181,7 +180,7 @@ def adadelta(parameters: dict) -> torch.optim.Adadelta: ) -def adagrad(parameters: dict) -> torch.optim.Adagrad: +def adagrad(parameters) -> Optimizer: """ Creates an Adagrad optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -203,7 +202,7 @@ def adagrad(parameters: dict) -> torch.optim.Adagrad: ) -def rmsprop(parameters: dict) -> torch.optim.RMSprop: +def rmsprop(parameters) -> Optimizer: """ Creates an RMSprop optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -226,7 +225,7 @@ def rmsprop(parameters: dict) -> torch.optim.RMSprop: ) -def radam(parameters: dict) -> torch.optim.RAdam: +def radam(parameters) -> Optimizer: """ Creates a RAdam optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -247,7 +246,7 @@ def radam(parameters: dict) -> torch.optim.RAdam: ) -def nadam(parameters: dict) -> torch.optim.NAdam: +def nadam(parameters) -> Optimizer: """ Creates a NAdam optimizer from the PyTorch `torch.optim` module using the input parameters. diff --git a/GANDLF/utils/gandlf_logging.py b/GANDLF/utils/gandlf_logging.py index 576df868e..c43d7679d 100644 --- a/GANDLF/utils/gandlf_logging.py +++ b/GANDLF/utils/gandlf_logging.py @@ -8,7 +8,7 @@ def _create_tmp_log_file(): - tmp_dir = Path(tempfile.gettempdir()) + tmp_dir = Path(Path.home()) log_dir = Path.joinpath(tmp_dir, ".gandlf") log_dir.mkdir(parents=True, exist_ok=True) log_file = Path.joinpath(log_dir, get_unique_timestamp() + ".log") diff --git a/GANDLF/version.py b/GANDLF/version.py index 969da2b13..5e5047feb 100644 --- a/GANDLF/version.py +++ b/GANDLF/version.py @@ -2,4 +2,4 @@ # -*- coding: UTF-8 -*- # check GaNDLF wiki for versioning and release guidelines: https://github.com/mlcommons/GaNDLF/wiki -__version__ = "0.1.1-dev" +__version__ = "0.1.2-dev" diff --git a/README.md b/README.md index a267fa424..89c69c375 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,6 @@ -# GaNDLF +![GitHub-Mark-Light](https://github.com/mlcommons/GaNDLF/blob/master/docs/images/logo/full.png?raw=true#gh-light-mode-only) + +![GitHub-Mark-Dark](https://github.com/mlcommons/GaNDLF/blob/master/docs/images/logo/full_black.png?raw=true#gh-dark-mode-only)

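For reference, the optimizer pieces above (the `thirdparty` submodule, the `lion`/`adopt` keys registered in `global_optimizer_dict`, and wrappers that all return a `torch.optim.Optimizer`) come together in `get_optimizer`. A minimal sketch, assuming a hand-assembled parameter dictionary in place of GaNDLF's parsed YAML configuration:

```python
import torch

from GANDLF.optimizers import get_optimizer

model = torch.nn.Linear(10, 2)  # toy stand-in for a GaNDLF model
params = {
    "model_parameters": model.parameters(),
    "learning_rate": 1e-3,
    # "type" is looked up in global_optimizer_dict; "adopt" and "lion"
    # are the keys registered by this diff
    "optimizer": {"type": "adopt", "decoupled": False},
}
optimizer = get_optimizer(params)  # raises an AssertionError for unknown types
```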
diff --git a/docs/faq.md b/docs/faq.md index dfe0c2234..0bb98239d 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -53,6 +53,10 @@ Please see https://mlcommons.github.io/GaNDLF/usage/#federating-your-model-evaluation Please read the [migration guide](https://mlcommons.github.io/GaNDLF/migration_guide) to understand the changes that have been made to GaNDLF. If you have any questions, please feel free to [post a support request](https://github.com/mlcommons/GaNDLF/issues/new?assignees=&labels=&template=--questions-help-support.md&title=). +### I am getting an error related to a version mismatch (greater or smaller) between the configuration and the GaNDLF version. What should I do? + +This is a safety feature to ensure a tight integration between the configuration used to define a model and the code version used to perform the training. Ensure that you have all requirements satisfied, and then check the ``version`` key in the configuration, and ensure it appropriately matches the output of ``gandlf run --version``. + ### What if I have another question? Please [post a support request](https://github.com/mlcommons/GaNDLF/issues/new?assignees=&labels=&template=--questions-help-support.md&title=). diff --git a/docs/images/logo/full.png b/docs/images/logo/full.png new file mode 100644 index 000000000..ea558f866 Binary files /dev/null and b/docs/images/logo/full.png differ diff --git a/docs/images/logo/full_black.png b/docs/images/logo/full_black.png new file mode 100644 index 000000000..85fba87bd Binary files /dev/null and b/docs/images/logo/full_black.png differ diff --git a/docs/setup.md b/docs/setup.md index 9f9cb5397..c4ee2fc18 100644 --- a/docs/setup.md +++ b/docs/setup.md @@ -36,7 +36,7 @@ You may install pytorch to be compatible with CUDA, ROCm, or CPU-only. An exhaus Use one of the following depending on your needs: - CUDA 12.1 ```bash -(venv_gandlf) $> pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121 +(venv_gandlf) $> pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu121 ``` ### Optional Dependencies diff --git a/docs/usage.md b/docs/usage.md index 24738d056..386adbde3 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -176,6 +176,14 @@ To split the data CSV into training, validation, and testing CSVs, the `gandlf s -o $output_dir # the output directory to save the split data ``` +### Using the `--log-file` parameter +By default, only the `info` and `error` logs will be **displayed** in the console and +the log file will be **saved** in `$(home)/.gandlf/.log`.
diff --git a/docs/images/logo/full.png b/docs/images/logo/full.png
new file mode 100644
index 000000000..ea558f866
Binary files /dev/null and b/docs/images/logo/full.png differ
diff --git a/docs/images/logo/full_black.png b/docs/images/logo/full_black.png
new file mode 100644
index 000000000..85fba87bd
Binary files /dev/null and b/docs/images/logo/full_black.png differ
diff --git a/docs/setup.md b/docs/setup.md
index 9f9cb5397..c4ee2fc18 100644
--- a/docs/setup.md
+++ b/docs/setup.md
@@ -36,7 +36,7 @@ You may install pytorch to be compatible with CUDA, ROCm, or CPU-only. An exhaus
 Use one of the following depending on your needs:
 - CUDA 12.1
 ```bash
-(venv_gandlf) $> pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121
+(venv_gandlf) $> pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu121
 ```

 ### Optional Dependencies
diff --git a/docs/usage.md b/docs/usage.md
index 24738d056..386adbde3 100644
--- a/docs/usage.md
+++ b/docs/usage.md
@@ -176,6 +176,14 @@ To split the data CSV into training, validation, and testing CSVs, the `gandlf s
   -o $output_dir # the output directory to save the split data
 ```

+### Using the `--log-file` parameter
+By default, only the `info` and `error` logs are **displayed** in the console, and
+the log file is **saved** under `$(home)/.gandlf/` with a timestamped filename.
+
+Alternatively, you can pass the `--log-file` parameter with the file in which the logs should be saved:
+```bash
+(venv_gandlf) $> gandlf --log-file <log_file_path>
+```

 ## Customize the Training
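For context on the documentation above, a hedged sketch of how a `--log-file` override is typically wired into a CLI; this is illustrative plumbing only, not GaNDLF's actual implementation:

```python
import argparse
import logging
from pathlib import Path
from typing import Optional


def setup_logging(log_file: Optional[str] = None) -> None:
    # Fall back to the documented default location under the home directory
    # when no override is given; "run.log" is a placeholder filename.
    target = Path(log_file) if log_file else Path.home() / ".gandlf" / "run.log"
    target.parent.mkdir(parents=True, exist_ok=True)
    logging.basicConfig(
        filename=str(target),
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
    )


parser = argparse.ArgumentParser()
parser.add_argument("--log-file", default=None, help="optional log destination")
args = parser.parse_args([])  # empty argv keeps the demo self-contained
setup_logging(args.log_file)
logging.info("logging initialized")
```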
diff --git a/pyproject.toml b/pyproject.toml
index 67a73e9dd..c65e14363 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,4 +16,5 @@ omit = [
     "./setup.py",
     "./testing/conftest.py",
     "./tutorials/*",
+    "./GANDLF/optimizers/thirdparty/*",
 ]
diff --git a/samples/config_all_options.yaml b/samples/config_all_options.yaml
index 51edecb01..872d65c44 100644
--- a/samples/config_all_options.yaml
+++ b/samples/config_all_options.yaml
@@ -1,8 +1,8 @@
 # affix version
 version:
   {
-    minimum: 0.1.1-dev,
-    maximum: 0.1.1-dev # this should NOT be made a variable, but should be tested after every tag is created
+    minimum: 0.1.2-dev,
+    maximum: 0.1.2-dev # this should NOT be made a variable, but should be tested after every tag is created
   }
 ## Choose the model parameters here
 model:
diff --git a/samples/config_classification.yaml b/samples/config_classification.yaml
index 8dbd3a082..e8b720520 100644
--- a/samples/config_classification.yaml
+++ b/samples/config_classification.yaml
@@ -1,8 +1,8 @@
 # affix version
 version:
   {
-    minimum: 0.1.1-dev,
-    maximum: 0.1.1-dev # this should NOT be made a variable, but should be tested after every tag is created
+    minimum: 0.1.2-dev,
+    maximum: 0.1.2-dev # this should NOT be made a variable, but should be tested after every tag is created
   }
 # Choose the model parameters here
 model:
diff --git a/samples/config_getting_started_classification_histo2d.yaml b/samples/config_getting_started_classification_histo2d.yaml
index 5d59c3c3a..f824fbd92 100644
--- a/samples/config_getting_started_classification_histo2d.yaml
+++ b/samples/config_getting_started_classification_histo2d.yaml
@@ -94,6 +94,6 @@ scheduler:
 track_memory_usage: false
 verbose: false
 version:
-  maximum: 0.1.1-dev
+  maximum: 0.1.2-dev
   minimum: 0.0.14
 weighted_loss: true
diff --git a/samples/config_getting_started_classification_rad3d.yaml b/samples/config_getting_started_classification_rad3d.yaml
index 83c5ef6b8..109e001e6 100644
--- a/samples/config_getting_started_classification_rad3d.yaml
+++ b/samples/config_getting_started_classification_rad3d.yaml
@@ -99,6 +99,6 @@ scheduler:
 track_memory_usage: false
 verbose: false
 version:
-  maximum: 0.1.1-dev
+  maximum: 0.1.2-dev
   minimum: 0.0.14
 weighted_loss: true
diff --git a/samples/config_getting_started_regression_histo2d.yaml b/samples/config_getting_started_regression_histo2d.yaml
index 0101d2ccc..fa2a41e2f 100644
--- a/samples/config_getting_started_regression_histo2d.yaml
+++ b/samples/config_getting_started_regression_histo2d.yaml
@@ -59,6 +59,6 @@ scheduler:
 track_memory_usage: false
 verbose: false
 version:
-  maximum: 0.1.1-dev
+  maximum: 0.1.2-dev
   minimum: 0.0.14
 weighted_loss: true
diff --git a/samples/config_getting_started_regression_rad3d.yaml b/samples/config_getting_started_regression_rad3d.yaml
index 39db08ad4..8ce80e1d1 100644
--- a/samples/config_getting_started_regression_rad3d.yaml
+++ b/samples/config_getting_started_regression_rad3d.yaml
@@ -62,6 +62,6 @@ scheduler:
 track_memory_usage: false
 verbose: false
 version:
-  maximum: 0.1.1-dev
+  maximum: 0.1.2-dev
   minimum: 0.0.14
 weighted_loss: false
diff --git a/samples/config_getting_started_segmentation_histo2d.yaml b/samples/config_getting_started_segmentation_histo2d.yaml
index 67871e07d..13ca80436 100644
--- a/samples/config_getting_started_segmentation_histo2d.yaml
+++ b/samples/config_getting_started_segmentation_histo2d.yaml
@@ -66,6 +66,6 @@ scheduler:
 track_memory_usage: false
 verbose: true
 version:
-  maximum: 0.1.1-dev
+  maximum: 0.1.2-dev
   minimum: 0.0.14
 weighted_loss: true
diff --git a/samples/config_getting_started_segmentation_rad3d.yaml b/samples/config_getting_started_segmentation_rad3d.yaml
index 39316b755..758163ff6 100644
--- a/samples/config_getting_started_segmentation_rad3d.yaml
+++ b/samples/config_getting_started_segmentation_rad3d.yaml
@@ -89,6 +89,6 @@ scheduler:
 track_memory_usage: false
 verbose: true
 version:
-  maximum: 0.1.1-dev
-  minimum: 0.1.1-dev
+  maximum: 0.1.2-dev
+  minimum: 0.1.2-dev
 weighted_loss: true
diff --git a/samples/config_regression.yaml b/samples/config_regression.yaml
index a0c1d65f8..0f4b91737 100644
--- a/samples/config_regression.yaml
+++ b/samples/config_regression.yaml
@@ -1,8 +1,8 @@
 # affix version
 version:
   {
-    minimum: 0.1.1-dev,
-    maximum: 0.1.1-dev # this should NOT be made a variable, but should be tested after every tag is created
+    minimum: 0.1.2-dev,
+    maximum: 0.1.2-dev # this should NOT be made a variable, but should be tested after every tag is created
   }
 # Choose the model parameters here
 model:
diff --git a/samples/config_segmentation_brats.yaml b/samples/config_segmentation_brats.yaml
index d56a30e2f..c8a5ac005 100644
--- a/samples/config_segmentation_brats.yaml
+++ b/samples/config_segmentation_brats.yaml
@@ -1,8 +1,8 @@
 # affix version
 version:
   {
-    minimum: 0.1.1-dev,
-    maximum: 0.1.1-dev # this should NOT be made a variable, but should be tested after every tag is created
+    minimum: 0.1.2-dev,
+    maximum: 0.1.2-dev # this should NOT be made a variable, but should be tested after every tag is created
   }
 # Choose the model parameters here
 model:
diff --git a/samples/config_segmentation_histology.yaml b/samples/config_segmentation_histology.yaml
index ddc6abd99..889ee9a98 100644
--- a/samples/config_segmentation_histology.yaml
+++ b/samples/config_segmentation_histology.yaml
@@ -1,8 +1,8 @@
 # affix version
 version:
   {
-    minimum: 0.1.1-dev,
-    maximum: 0.1.1-dev # this should NOT be made a variable, but should be tested after every tag is created
+    minimum: 0.1.2-dev,
+    maximum: 0.1.2-dev # this should NOT be made a variable, but should be tested after every tag is created
   }
 # Choose the model parameters here
 model:
diff --git a/setup.py b/setup.py
index 564df3045..bd75e5ae9 100644
--- a/setup.py
+++ b/setup.py
@@ -6,6 +6,7 @@
 import sys, re, os
 from setuptools import setup, find_packages
+
 try:
     with open("README.md") as readme_file:
         readme = readme_file.read()
@@ -13,7 +14,6 @@
     readme = "No README information found."
     sys.stderr.write("Warning: Could not open '%s' due %s\n" % ("README.md", error))

-
 try:
     filepath = "GANDLF/version.py"
     version_file = open(filepath)
@@ -37,7 +37,7 @@
 # specifying version for `black` separately because it is also used to [check for lint](https://github.com/mlcommons/GaNDLF/blob/master/.github/workflows/black.yml)
 black_version = "23.11.0"
 requirements = [
-    "torch==2.3.1",
+    "torch==2.5.0",
     f"black=={black_version}",
     "numpy==1.25.0",
     "scipy",
@@ -52,7 +52,6 @@
     "setuptools",
     "seaborn",
     "pyyaml==6.0.1",
-    "tiffslide",
     "matplotlib",
     "gdown==5.1.0",
     "pytest",
@@ -83,6 +82,9 @@
     "colorlog",
     "opacus==1.5.2",
     "huggingface-hub==0.25.1",
+    "openslide-bin",
+    "openslide-python==1.4.1",
+    "lion-pytorch==0.2.2",
 ]

 if __name__ == "__main__":
diff --git a/testing/config_classification.yaml b/testing/config_classification.yaml
index ea8f0f890..79dfb5feb 100644
--- a/testing/config_classification.yaml
+++ b/testing/config_classification.yaml
@@ -55,7 +55,7 @@ save_output: false
 scaling_factor: 1
 scheduler: triangle
 version:
-  maximum: 0.1.1-dev
+  maximum: 0.1.2-dev
   minimum: 0.0.14
 weighted_loss: True
diff --git a/testing/config_regression.yaml b/testing/config_regression.yaml
index 474ac163e..47b9e2aab 100644
--- a/testing/config_regression.yaml
+++ b/testing/config_regression.yaml
@@ -38,7 +38,7 @@ save_output: false
 scaling_factor: 1
 scheduler: triangle
 version:
-  maximum: 0.1.1-dev
+  maximum: 0.1.2-dev
   minimum: 0.0.14
 weighted_loss: false
diff --git a/testing/config_segmentation.yaml b/testing/config_segmentation.yaml
index e49d4b716..a275a6b8d 100644
--- a/testing/config_segmentation.yaml
+++ b/testing/config_segmentation.yaml
@@ -3,7 +3,7 @@
 version:
   {
     minimum: 0.0.14,
-    maximum: 0.1.1-dev
+    maximum: 0.1.2-dev
   }
 model:
   {
diff --git a/tutorials/classification_medmnist_notebook/config.yaml b/tutorials/classification_medmnist_notebook/config.yaml
index 5e7e1138c..f1035dc7d 100644
--- a/tutorials/classification_medmnist_notebook/config.yaml
+++ b/tutorials/classification_medmnist_notebook/config.yaml
@@ -2,7 +2,7 @@
 version:
   {
     minimum: 0.0.14,
-    maximum: 0.1.1-dev # this should NOT be made a variable, but should be tested after every tag is created
+    maximum: 0.1.2-dev # this should NOT be made a variable, but should be tested after every tag is created
   }
 # Choose the model parameters here
 model:
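One of the new `setup.py` requirements above, `lion-pytorch==0.2.2`, provides the Lion optimizer (its coverage exclusion under `GANDLF/optimizers/thirdparty/` also appears in the `pyproject.toml` hunk). A short sketch of the third-party library's documented entry point; the model and hyperparameters below are stand-ins, not values taken from GaNDLF:

```python
import torch
from lion_pytorch import Lion

model = torch.nn.Linear(8, 2)  # stand-in model for the sketch
# Lion is usually run with a smaller learning rate (and often a larger
# weight decay) than AdamW; these values are illustrative.
optimizer = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)

loss = model(torch.randn(1, 8)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```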