Merge pull request #10 from szmazurek/feature/torchmetrics_upgrade
Feature/torchmetrics upgrade
szmazurek authored Nov 19, 2023
2 parents dabc228 + baf5153 commit 1c962d7
Showing 7 changed files with 557 additions and 175 deletions.
81 changes: 61 additions & 20 deletions GANDLF/metrics/classification.py
@@ -1,5 +1,7 @@
import torchmetrics as tm
from torch.nn.functional import one_hot
from ..utils import get_output_from_calculator
from GANDLF.utils.generic import determine_task


def overall_stats(predictions, ground_truth, params):
@@ -26,42 +28,81 @@ def overall_stats(predictions, ground_truth, params):
"per_class_average": "macro",
"per_class_weighted": "weighted",
}
task = determine_task(params)
# consider adding a "multilabel" field in the future
# metrics that need the "average" parameter

for average_type, average_type_key in average_types_keys.items():
calculators = {
"accuracy": tm.Accuracy(
num_classes=params["model"]["num_classes"], average=average_type_key
task=task,
num_classes=params["model"]["num_classes"],
average=average_type_key,
),
"precision": tm.Precision(
num_classes=params["model"]["num_classes"], average=average_type_key
task=task,
num_classes=params["model"]["num_classes"],
average=average_type_key,
),
"recall": tm.Recall(
num_classes=params["model"]["num_classes"], average=average_type_key
task=task,
num_classes=params["model"]["num_classes"],
average=average_type_key,
),
"f1": tm.F1Score(
num_classes=params["model"]["num_classes"], average=average_type_key
task=task,
num_classes=params["model"]["num_classes"],
average=average_type_key,
),
"specificity": tm.Specificity(
num_classes=params["model"]["num_classes"], average=average_type_key
task=task,
num_classes=params["model"]["num_classes"],
average=average_type_key,
),
## weird error for multi-class problem, where pos_label is not getting set
# "aucroc": tm.AUROC(
# num_classes=params["model"]["num_classes"], average=average_type_key
# ),
"aucroc": tm.AUROC(
task=task,
num_classes=params["model"]["num_classes"],
average=average_type_key
                if average_type_key != "micro"
else "macro",
),
}
for metric_name, calculator in calculators.items():
output_metrics[
f"{metric_name}_{average_type}"
] = get_output_from_calculator(predictions, ground_truth, calculator)
if metric_name == "aucroc":
one_hot_preds = one_hot(
predictions.long(),
num_classes=params["model"]["num_classes"],
)
output_metrics[metric_name] = get_output_from_calculator(
one_hot_preds.float(), ground_truth, calculator
)
else:
output_metrics[metric_name] = get_output_from_calculator(
predictions, ground_truth, calculator
)

    #### NOTE: TESTS NEED TO BE MODIFIED - ROC RETURNS A TUPLE; WE MAY ALSO DISCARD IT ####
    # What is the AUC metric telling us at all? Computing it for predictions and ground
    # truth does not make sense.
# metrics that do not have any "average" parameter
calculators = {
"auc": tm.AUC(reorder=True),
## weird error for multi-class problem, where pos_label is not getting set
# "roc": tm.ROC(num_classes=params["model"]["num_classes"]),
}
for metric_name, calculator in calculators.items():
output_metrics[metric_name] = get_output_from_calculator(
predictions, ground_truth, calculator
)
# calculators = {
#
# # "auc": tm.AUC(reorder=True),
# ## weird error for multi-class problem, where pos_label is not getting set
# "roc": tm.ROC(task=task, num_classes=params["model"]["num_classes"]),
# }
# for metric_name, calculator in calculators.items():
# if metric_name == "roc":
# one_hot_preds = one_hot(
# predictions.long(), num_classes=params["model"]["num_classes"]
# )
# output_metrics[metric_name] = get_output_from_calculator(
# one_hot_preds.float(), ground_truth, calculator
# )
# else:
# output_metrics[metric_name] = get_output_from_calculator(
# predictions, ground_truth, calculator
# )

return output_metrics
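For reference, a minimal sketch of how the upgraded torchmetrics (>= 1.x) classification API is exercised above; the tensors and the num_classes value below are illustrative assumptions, not part of this PR.

import torch
import torchmetrics as tm
from torch.nn.functional import one_hot

num_classes = 3
task = "binary" if num_classes == 2 else "multiclass"  # mirrors determine_task()

preds = torch.tensor([0, 2, 1, 2])   # hard class predictions
target = torch.tensor([0, 1, 1, 2])  # ground-truth labels

# Accuracy/Precision/Recall/F1/Specificity now require an explicit `task` argument.
accuracy = tm.Accuracy(task=task, num_classes=num_classes, average="macro")
print(accuracy(preds, target))

# AUROC expects per-class scores, so hard labels are one-hot encoded and cast to
# float first, as done for "aucroc" in overall_stats(); multiclass AUROC does not
# support micro averaging, hence the fallback to "macro" in the code above.
auroc = tm.AUROC(task=task, num_classes=num_classes, average="macro")
print(auroc(one_hot(preds, num_classes=num_classes).float(), target))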
77 changes: 62 additions & 15 deletions GANDLF/metrics/generic.py
@@ -1,23 +1,54 @@
import torch
from torchmetrics import F1Score, Precision, Recall, JaccardIndex, Accuracy, Specificity
from torchmetrics import (
F1Score,
Precision,
Recall,
JaccardIndex,
Accuracy,
Specificity,
)
from GANDLF.utils.tensor import one_hot
from GANDLF.utils.generic import determine_task


def generic_function_output_with_check(predicted_classes, label, metric_function):
def define_average_type_key(params, metric_name):
"""Determine if the metric config defines the type of average to use.
    If not, fall back to the default (macro) average type.
"""
if "average" in params["metrics"][metric_name]:
average_type_key = params["metrics"][metric_name]["average"]
else:
average_type_key = "macro"
        print(
            "WARNING: Average type not defined in config, using default (macro)."
        )
return average_type_key


def generic_function_output_with_check(
predicted_classes, label, metric_function
):
if torch.min(predicted_classes) < 0:
print(
"WARNING: Negative values detected in prediction, cannot compute torchmetrics calculations."
)
return torch.zeros((1), device=predicted_classes.device)
else:
try:
max_clamp_val = metric_function.num_classes - 1
except AttributeError:
max_clamp_val = 1
predicted_new = torch.clamp(
predicted_classes.cpu().int(), max=metric_function.num_classes - 1
predicted_classes.cpu().int(), max=max_clamp_val
)
predicted_new = predicted_new.reshape(label.shape)
return metric_function(predicted_new, label.cpu().int())


def generic_torchmetrics_score(output, label, metric_class, metric_key, params):
def generic_torchmetrics_score(
output, label, metric_class, metric_key, params
):
task = determine_task(params)
num_classes = params["model"]["num_classes"]
predicted_classes = output
if params["problem_type"] == "classification":
@@ -28,10 +59,9 @@ def generic_torchmetrics_score(output, label, metric_class, metric_key, params):
params["metrics"][metric_key]["multi_class"] = False
params["metrics"][metric_key]["mdmc_average"] = None
metric_function = metric_class(
task=task,
average=params["metrics"][metric_key]["average"],
num_classes=num_classes,
multiclass=params["metrics"][metric_key]["multi_class"],
mdmc_average=params["metrics"][metric_key]["mdmc_average"],
threshold=params["metrics"][metric_key]["threshold"],
)

@@ -45,19 +75,25 @@ def recall_score(output, label, params):


def precision_score(output, label, params):
return generic_torchmetrics_score(output, label, Precision, "precision", params)
return generic_torchmetrics_score(
output, label, Precision, "precision", params
)


def f1_score(output, label, params):
return generic_torchmetrics_score(output, label, F1Score, "f1", params)


def accuracy(output, label, params):
return generic_torchmetrics_score(output, label, Accuracy, "accuracy", params)
return generic_torchmetrics_score(
output, label, Accuracy, "accuracy", params
)


def specificity_score(output, label, params):
return generic_torchmetrics_score(output, label, Specificity, "specificity", params)
return generic_torchmetrics_score(
output, label, Specificity, "specificity", params
)


def iou_score(output, label, params):
@@ -67,12 +103,23 @@ def iou_score(output, label, params):
predicted_classes = torch.argmax(output, 1)
elif params["problem_type"] == "segmentation":
label = one_hot(label, params["model"]["class_list"])

recall = JaccardIndex(
reduction=params["metrics"]["iou"]["reduction"],
num_classes=num_classes,
threshold=params["metrics"]["iou"]["threshold"],
)
task = determine_task(params)
if task == "binary":
recall = JaccardIndex(
task=task,
threshold=params["metrics"]["iou"]["threshold"],
)
elif task == "multiclass":
recall = JaccardIndex(
task=task,
average=define_average_type_key(params, "iou"),
num_classes=num_classes,
threshold=params["metrics"]["iou"]["threshold"],
)
else:
raise NotImplementedError(
"IoU score is not implemented for multilabel problems"
)

return generic_function_output_with_check(
predicted_classes.cpu().int(), label.cpu().int(), recall
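As a rough sketch of the task-dependent JaccardIndex construction introduced in iou_score() above (tensor values and the num_classes used here are illustrative assumptions):

import torch
from torchmetrics import JaccardIndex

# multiclass: `num_classes` and an averaging mode are required in torchmetrics 1.x
iou_multiclass = JaccardIndex(task="multiclass", num_classes=3, average="macro")
print(iou_multiclass(torch.tensor([0, 1, 2, 1]), torch.tensor([0, 2, 2, 1])))

# binary: only a decision `threshold` is needed
iou_binary = JaccardIndex(task="binary", threshold=0.5)
print(iou_binary(torch.tensor([0, 1, 1, 0]), torch.tensor([0, 1, 0, 0])))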
1 change: 1 addition & 0 deletions GANDLF/utils/__init__.py
@@ -50,6 +50,7 @@
suppress_stdout_stderr,
set_determinism,
print_and_format_metrics,
determine_task,
)

from .modelio import (
23 changes: 20 additions & 3 deletions GANDLF/utils/generic.py
@@ -6,6 +6,7 @@
import SimpleITK as sitk
from contextlib import contextmanager, redirect_stderr, redirect_stdout
from os import devnull
from typing import Dict, Any, Union


@contextmanager
@@ -48,6 +49,15 @@ def checkPatchDivisibility(patch_size, number=16):
return True


def determine_task(params: Dict[str, Union[Dict[str, Any], Any]]) -> str:
    """Determine the task (binary or multiclass) from the model config.

    Args:
        params (dict): The parameter dictionary containing training and data information.

    Returns:
        str: "binary" if the model defines two classes, "multiclass" otherwise.
    """
    task = "binary" if params["model"]["num_classes"] == 2 else "multiclass"
    return task


def get_date_time():
"""
Get a well-parsed date string
@@ -146,7 +156,10 @@ def checkPatchDimensions(patch_size, numlay):
patch_size_to_check = patch_size_to_check[:-1]

if all(
[x >= 2 ** (numlay + 1) and x % 2**numlay == 0 for x in patch_size_to_check]
[
x >= 2 ** (numlay + 1) and x % 2**numlay == 0
for x in patch_size_to_check
]
):
return numlay
else:
@@ -182,7 +195,9 @@ def get_array_from_image_or_tensor(input_tensor_or_image):
elif isinstance(input_tensor_or_image, np.ndarray):
return input_tensor_or_image
else:
raise ValueError("Input must be a torch.Tensor or sitk.Image or np.ndarray")
raise ValueError(
"Input must be a torch.Tensor or sitk.Image or np.ndarray"
)


def set_determinism(seed=42):
@@ -252,7 +267,9 @@ def __update_metric_from_list_to_single_string(input_metrics_dict) -> dict:
output_metrics_dict = deepcopy(cohort_level_metrics)
for metric in metrics_dict_from_parameters:
if isinstance(sample_level_metrics[metric], np.ndarray):
to_print = (sample_level_metrics[metric] / length_of_dataloader).tolist()
to_print = (
sample_level_metrics[metric] / length_of_dataloader
).tolist()
else:
to_print = sample_level_metrics[metric] / length_of_dataloader
output_metrics_dict[metric] = to_print
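A hypothetical, minimal params dictionary to illustrate the determine_task() helper added above (the real GaNDLF parameter dictionary carries much more information):

from GANDLF.utils.generic import determine_task

params = {
    "model": {"num_classes": 2},
    "metrics": {"iou": {"average": "micro", "threshold": 0.5}},
}

print(determine_task(params))  # -> "binary"

params["model"]["num_classes"] = 4
print(determine_task(params))  # -> "multiclass"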
14 changes: 10 additions & 4 deletions setup.py
@@ -14,7 +14,9 @@
readme = readme_file.read()
except Exception as error:
readme = "No README information found."
sys.stderr.write("Warning: Could not open '%s' due %s\n" % ("README.md", error))
sys.stderr.write(
"Warning: Could not open '%s' due %s\n" % ("README.md", error)
)


class CustomInstallCommand(install):
@@ -39,7 +41,9 @@ def run(self):

except Exception as error:
__version__ = "0.0.1"
sys.stderr.write("Warning: Could not open '%s' due %s\n" % (filepath, error))
sys.stderr.write(
"Warning: Could not open '%s' due %s\n" % (filepath, error)
)

# Handle cases where specific files need to be bundled into the final package as installed via PyPI
dockerfiles = [
@@ -54,7 +58,9 @@ def run(self):
]
setup_files = ["setup.py", ".dockerignore", "pyproject.toml", "MANIFEST.in"]
all_extra_files = dockerfiles + entrypoint_files + setup_files
all_extra_files_pathcorrected = [os.path.join("../", item) for item in all_extra_files]
all_extra_files_pathcorrected = [
os.path.join("../", item) for item in all_extra_files
]
# find_packages should only ever find these as subpackages of gandlf, not as top-level packages
# generate this dynamically?
# GANDLF.GANDLF is needed to prevent recursion madness in deployments
@@ -99,7 +105,7 @@ def run(self):
"psutil",
"medcam",
"opencv-python",
"torchmetrics==0.8.1",
"torchmetrics==1.1.2",
"zarr==2.10.3",
"pydicom",
"onnx",
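Since setup.py now pins torchmetrics==1.1.2, a quick environment sanity check can be useful; this is an editorial sketch (assuming the packaging module is available, as it usually is alongside setuptools), not part of the PR:

from packaging import version
import torchmetrics

required = "1.1.2"
if version.parse(torchmetrics.__version__) < version.parse(required):
    raise RuntimeError(
        f"torchmetrics {torchmetrics.__version__} found; this codebase expects "
        f"{required} with the task-based metric API."
    )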
2 changes: 1 addition & 1 deletion testing/config_classification.yaml
@@ -21,7 +21,7 @@ metrics:
- recall
- specificity
- iou: {
reduction: sum,
average: micro,
}

modality: rad