diff --git a/docs/user_guide/12_evaluation.md b/docs/user_guide/12_evaluation.md
index b060bbee..e24b4e3c 100644
--- a/docs/user_guide/12_evaluation.md
+++ b/docs/user_guide/12_evaluation.md
@@ -28,6 +28,10 @@ There is an additional difference between ecological object detection methods li
 DeepForest uses the [hungarian matching algorithm](https://thinkautonomous.medium.com/computer-vision-for-tracking-8220759eee85) to assign predictions to ground truth based on maximum IoU overlap. This is slow compared to the methods above, and so isn't a good choice for running hundreds of times during model training; see config["validation"]["val_accuracy_interval"] for setting the frequency of the evaluate callback for this metric.
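+
+For example, to run this slower evaluation only every five epochs (a minimal sketch; `m` is a `deepforest` model object):
+
+```python
+m.config["validation"]["val_accuracy_interval"] = 5
+```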
+
+### Empty Frame Accuracy
+
+DeepForest allows the user to pass empty frames to evaluation by setting xmin, ymin, xmax, and ymax to 0. This is useful for evaluating models on data that contains empty frames. The empty frame accuracy is the proportion of empty frames that contain no predictions. The 'label' column is ignored for empty frames, but it must be one of the labels in the model for the frame to be included in the evaluation.
+
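+A minimal sketch of the workflow (the file names here are hypothetical; the `empty_frame_accuracy` key matches the metric logged during validation):
+
+```python
+import pandas as pd
+from deepforest import main
+
+# Mark a frame as empty: zero out all four coordinates.
+# The label must still be one of the model's classes, e.g. "Tree".
+annotations = pd.DataFrame({
+    "image_path": ["empty_frame.png"],
+    "xmin": [0], "ymin": [0], "xmax": [0], "ymax": [0],
+    "label": ["Tree"],
+})
+annotations.to_csv("annotations_with_empty_frames.csv", index=False)
+
+m = main.deepforest()
+m.config["validation"]["csv_file"] = "annotations_with_empty_frames.csv"
+m.config["validation"]["root_dir"] = "images/"
+m.create_trainer()
+results = m.trainer.validate(m)
+print(results[0]["empty_frame_accuracy"])
+```
+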
 # Calculating Evaluation Metrics

 ## Torchmetrics and loss scores
diff --git a/src/deepforest/main.py b/src/deepforest/main.py
index f2f1695d..8e55ed92 100644
--- a/src/deepforest/main.py
+++ b/src/deepforest/main.py
@@ -4,7 +4,6 @@
 import typing
 import warnings

-import geopandas as gpd
 import numpy as np
 import pandas as pd
 import pytorch_lightning as pl
@@ -14,15 +13,31 @@
 from pytorch_lightning.callbacks import LearningRateMonitor
 from torch import optim
 from torchmetrics.detection import IntersectionOverUnion, MeanAveragePrecision
+from torchmetrics.classification import BinaryAccuracy
+
 from huggingface_hub import PyTorchModelHubMixin

 from deepforest import dataset, visualize, get_data, utilities, predict
 from deepforest import evaluate as evaluate_iou
 from huggingface_hub import PyTorchModelHubMixin
+from lightning_fabric.utilities.exceptions import MisconfigurationException


 class deepforest(pl.LightningModule, PyTorchModelHubMixin):
-    """Class for training and predicting tree crowns in RGB images."""
+    """Class for training and predicting tree crowns in RGB images.
+
+    Args:
+        num_classes (int): number of classes in the model
+        config_file (str): path to deepforest config file
+        model (model.Model()): a deepforest model object, see model.Model()
+        config_args (dict): a dictionary of key->value pairs to update the config at run time,
+            e.g. {"batch_size": 10}. This is useful for iterating over arguments during model testing.
+        existing_train_dataloader: a Pytorch dataloader that yields a tuple (path, images, targets)
+        existing_val_dataloader: a Pytorch dataloader that yields a tuple (path, images, targets)
+
+    Returns:
+        self: a deepforest pytorch lightning module
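+
+    Example:
+        >>> # Override config values at construction time rather than editing
+        >>> # the config file, e.g. a smaller batch size for testing:
+        >>> m = deepforest(config_args={"batch_size": 2})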
+    """

     def __init__(self,
                  num_classes: int = 1,
@@ -33,18 +48,7 @@
                  model=None,
                  existing_train_dataloader=None,
                  existing_val_dataloader=None):
-        """Args:
-            num_classes (int): number of classes in the model
-            config_file (str): path to deepforest config file
-            model (model.Model()): a deepforest model object, see model.Model().
-            config_args (dict): a dictionary of key->value to update
-                config file at run time. e.g. {"batch_size":10}
-                This is useful for iterating over arguments during model testing.
-            existing_train_dataloader: a Pytorch dataloader that yields a tuple path, images, targets
-            existing_val_dataloader: a Pytorch dataloader that yields a tuple path, images, targets
-        Returns:
-            self: a deepforest pytorch lightning module
-        """
+
         super().__init__()

         # Read config file. Defaults to deepforest_config.yml in working directory.
@@ -97,6 +101,9 @@
             class_metrics=True,
             iou_threshold=self.config["validation"]["iou_threshold"])
         self.mAP_metric = MeanAveragePrecision()
+        # Empty frame accuracy
+        self.empty_frame_accuracy = BinaryAccuracy()
+
         # Create a default trainer.
         self.create_trainer()
@@ -186,12 +193,8 @@ def create_model(self):
         models/, as is a subclass of model.Model(). The config args in the .yaml are specified.

-        >>> # retinanet:
-        >>> #     ms_thresh: 0.1
-        >>> #     score_thresh: 0.2
-        >>> # RCNN:
-        >>> #     nms_thresh: 0.1
-        >>> #     etc.
+        Returns:
+            None
         """
         if self.model is None:
             model_name = importlib.import_module("deepforest.models.{}".format(
@@ -202,7 +205,12 @@
     def create_trainer(self, logger=None, callbacks=[], **kwargs):
         """Create a pytorch lightning training by reading config files.

         Args:
+            logger: A pytorch lightning logger
             callbacks (list): a list of pytorch-lightning callback classes
+            **kwargs: Additional arguments to pass to the trainer
+
+        Returns:
+            None
         """
         # If val data is passed, monitor learning rate and setup classification metrics
         if not self.config["validation"]["csv_file"] is None:
@@ -351,14 +359,15 @@
     def predict_image(self,
                       thickness: int = 1):
         """Predict a single image with a deepforest model.

-        Deprecation warning: The 'return_plot', and related 'color' and 'thickness' arguments are deprecated and will be removed in 2.0. Use visualize.plot_results on the result instead.
+        Deprecation warning: The 'return_plot', and related 'color' and 'thickness' arguments
+        are deprecated and will be removed in 2.0. Use visualize.plot_results on the result instead.

         Args:
             image: a float32 numpy array of a RGB with channels last format
             path: optional path to read image from disk instead of passing image arg
-            (deprecated) return_plot: return a plot of the image with predictions overlaid
-            (deprecated) color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
-            (deprectaed) thickness: thickness of the rectangle border line in px
+            return_plot: return a plot of the image with predictions overlaid (deprecated)
+            color: color of the bounding box as a tuple of BGR color (deprecated)
+            thickness: thickness of the rectangle border line in px (deprecated)

         Returns:
             result: A pandas dataframe of predictions (Default)
@@ -484,41 +493,23 @@ def predict_tile(self,
         Args:
             raster_path: Path to image on disk
-            image (array): Numpy image array in BGR channel order
-                following openCV convention
-            patch_size: patch size for each window.
-            patch_overlap: patch overlap among windows.
-            iou_threshold: Minimum iou overlap among predictions between
-                windows to be suppressed.
-                Lower values suppress more boxes at edges.
-            in_memory: If true, the entire dataset is loaded into memory, which increases speed. This is useful for small datasets, but not recommended for very large datasets.
+            image (array): Numpy image array in BGR channel order following openCV convention
+            patch_size: patch size for each window
+            patch_overlap: patch overlap among windows
+            iou_threshold: Minimum iou overlap among predictions between windows to be suppressed
+            in_memory: If true, the entire dataset is loaded into memory
             mosaic: Return a single prediction dataframe (True) or a tuple of image crops and predictions (False)
             sigma: variance of Gaussian function used in Gaussian Soft NMS
             thresh: the score thresh used to filter bboxes after soft-nms performed
-            cropModel: a deepforest.model.CropModel object to predict on crops
+            crop_model: a deepforest.model.CropModel object to predict on crops
             crop_transform: a torchvision.transforms object to apply to crops
             crop_augment: a boolean to apply augmentations to crops
-            (deprecated) return_plot: return a plot of the image with predictions overlaid
-            (deprecated) color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
-            (deprecated) thickness: thickness of the rectangle border line in px
-
-        Deprecated Args:
-            - return_plot: Deprecated in favor of using `visualize.plot_results` for
-                rendering predictions. Will be removed in version 2.0.
-            - color: Deprecated bounding box color for visualizations.
-            - thickness: Deprecated bounding box thickness for visualizations.
-
-        Raises:
-            - ValueError: If `raster_path` is None when `in_memory=False`.
-            - ValueError: If `workers` is greater than 0 when `in_memory=False`. Multiprocessing is not supported when using out-of-memory datasets, rasterio is not threadsafe.
+            return_plot: return a plot of the image with predictions overlaid (deprecated)
+            color: color of the bounding box as a tuple of BGR color (deprecated)
+            thickness: thickness of the rectangle border line in px (deprecated)

         Returns:
-            - If `return_plot` is True, returns an image with predictions overlaid (deprecated).
-            - If `mosaic` is True, returns a Pandas DataFrame containing the predicted
-                bounding boxes, scores, and labels.
-            - If `mosaic` is False, returns a list of tuples where each tuple contains
-                a DataFrame of predictions and its corresponding image crop.
-            - Returns None if no predictions are made.
+            pd.DataFrame or tuple: Predictions dataframe or (predictions, crops) tuple
         """
         self.model.eval()
         self.model.nms_thresh = self.config["nms_thresh"]
@@ -642,7 +633,6 @@ def training_step(self, batch, batch_idx):
         # allow for empty data if data augmentation is generated
         path, images, targets = batch
-
         loss_dict = self.model.forward(images, targets)

         # sum of regression and classification loss
@@ -665,7 +655,7 @@ def validation_step(self, batch, batch_idx):
             print("Empty batch encountered, skipping")
             return None

-        # Get loss from "train" mode, but don't allow optimization
+        # Get loss from "train" mode, but don't allow optimization. Torchvision models have a 'train' mode that returns losses and an 'eval' mode that returns predictions. The names are confusing, but this is the correct way to get the loss.
         self.model.train()
         with torch.no_grad():
             loss_dict = self.model.forward(images, targets)
@@ -674,6 +664,7 @@
         losses = sum([loss for loss in loss_dict.values()])

         self.model.eval()
+        # Can we avoid another forward pass here? https://discuss.pytorch.org/t/how-to-get-losses-and-predictions-at-the-same-time/167223
         preds = self.model.forward(images)

         # Calculate intersection-over-union
@@ -682,31 +673,97 @@
         # Log loss
         for key, value in loss_dict.items():
-            self.log("val_{}".format(key), value, on_epoch=True)
+            try:
+                self.log("val_{}".format(key), value, on_epoch=True)
+            except MisconfigurationException:
+                pass

         for index, result in enumerate(preds):
             # Skip empty predictions
             if result["boxes"].shape[0] == 0:
-                continue
-            boxes = visualize.format_geometry(result)
-            boxes["image_path"] = path[index]
-            self.predictions.append(boxes)
+                # Keep a placeholder row so empty frames are still represented in predictions_df
+                self.predictions.append(
+                    pd.DataFrame({
+                        "image_path": [path[index]],
+                        "xmin": [None],
+                        "ymin": [None],
+                        "xmax": [None],
+                        "ymax": [None],
+                        "label": [None],
+                        "score": [None]
+                    }))
+            else:
+                boxes = visualize.format_geometry(result)
+                boxes["image_path"] = path[index]
+                self.predictions.append(boxes)

         return losses

     def on_validation_epoch_start(self):
         self.predictions = []

+    def calculate_empty_frame_accuracy(self, ground_df, predictions_df):
+        """Calculate accuracy for empty frames (frames with no objects).
+
+        Args:
+            ground_df (pd.DataFrame): Ground truth dataframe containing image paths and bounding boxes.
+                Must have columns 'image_path', 'xmin', 'ymin', 'xmax', 'ymax'.
+            predictions_df (pd.DataFrame): Model predictions dataframe containing image paths and predicted boxes.
+                Must have column 'image_path'.
+
+        Returns:
+            float or None: Accuracy score for empty frame detection. A score of 1.0 means the model correctly
+                identified all empty frames (no false positives), while 0.0 means it predicted objects
+                in all empty frames (all false positives). Returns None if there are no empty frames.
+        """
+        # Find images that are marked as empty in ground truth (all coordinates are 0)
+        empty_images = ground_df.loc[(ground_df.xmin == 0) & (ground_df.ymin == 0) &
+                                     (ground_df.xmax == 0) & (ground_df.ymax == 0),
+                                     "image_path"].unique()
+
+        if len(empty_images) == 0:
+            return None
+
+        # Get non-empty predictions for empty images
+        non_empty_predictions = predictions_df.loc[predictions_df.xmin.notnull()]
+        predictions_for_empty_images = non_empty_predictions.loc[
+            non_empty_predictions.image_path.isin(empty_images)]
+
+        # Create prediction tensor - 1 if model predicted objects, 0 if predicted empty
+        predictions = torch.zeros(len(empty_images))
+        for index, image in enumerate(empty_images):
+            if len(predictions_for_empty_images.loc[
+                    predictions_for_empty_images.image_path == image]) > 0:
+                predictions[index] = 1
+
+        # Ground truth tensor - all zeros since these are empty frames
+        gt = torch.zeros(len(empty_images))
+
+        # Calculate accuracy using metric
+        self.empty_frame_accuracy.update(predictions, gt)
+        empty_accuracy = self.empty_frame_accuracy.compute()
+
+        return empty_accuracy
+
     def on_validation_epoch_end(self):
+        """Compute metrics."""
+
         output = self.iou_metric.compute()
-        self.log_dict(output)
-        self.iou_metric.reset()
+        try:
+            # This is a bug in Lightning; it claims this is a warning, but it is not.
+            # https://github.com/Lightning-AI/pytorch-lightning/pull/9733/files
+            self.log_dict(output)
+        except MisconfigurationException:
+            pass
+        self.iou_metric.reset()

         output = self.mAP_metric.compute()

         # Remove classes from output dict
         output = {key: value for key, value in output.items() if not key == "classes"}
-        self.log_dict(output)
+        try:
+            self.log_dict(output)
+        except MisconfigurationException:
+            pass
         self.mAP_metric.reset()

         if len(self.predictions) == 0:
@@ -720,8 +777,21 @@
         ground_df = utilities.read_file(self.config["validation"]["csv_file"])
         ground_df["label"] = ground_df.label.apply(lambda x: self.label_dict[x])

+        # If there are empty frames, evaluate empty frame accuracy separately
+        empty_accuracy = self.calculate_empty_frame_accuracy(
+            ground_df, self.predictions_df)
+
+        if empty_accuracy is not None:
+            try:
+                self.log("empty_frame_accuracy", empty_accuracy)
+            except MisconfigurationException:
+                pass
+
+        # Remove empty predictions from the rest of the evaluation
+        self.predictions_df = self.predictions_df.loc[
+            self.predictions_df.xmin.notnull()]

         if self.predictions_df.empty:
-            warnings.warn("No predictions made, skipping evaluation")
+            warnings.warn("No predictions made, skipping detection evaluation")
             geom_type = utilities.determine_geometry_type(ground_df)
             if geom_type == "box":
                 result = {
                     "class_recall": pd.DataFrame()
                 }
         else:
+            # Remove empty ground truth
+            ground_df = ground_df.loc[~(ground_df.xmin == 0)]
+            if ground_df.empty:
+                results = {}
+                results["empty_frame_accuracy"] = empty_accuracy
+                return results
+
             results = evaluate_iou.__evaluate_wrapper__(
                 predictions=self.predictions_df,
                 ground_df=ground_df,
@@ -738,23 +815,33 @@
                 savedir=None,
                 numeric_to_label_dict=self.numeric_to_label_dict)

+            if empty_accuracy is not None:
+                results["empty_frame_accuracy"] = empty_accuracy
+
             # Log each key value pair of the results dict
-            for key, value in results.items():
-                if key in ["class_recall"]:
-                    for index, row in value.iterrows():
-                        self.log(
-                            "{}_Recall".format(
-                                self.numeric_to_label_dict[row["label"]]),
-                            row["recall"])
-                        self.log(
-                            "{}_Precision".format(
-                                self.numeric_to_label_dict[row["label"]]),
-                            row["precision"])
-                else:
-                    try:
-                        self.log(key, value)
-                    except:
+            if results["class_recall"] is not None:
+                for key, value in results.items():
+                    if key in ["class_recall"]:
+                        for index, row in value.iterrows():
+                            try:
+                                self.log(
+                                    "{}_Recall".format(
+                                        self.numeric_to_label_dict[row["label"]]),
+                                    row["recall"])
+                                self.log(
+                                    "{}_Precision".format(
+                                        self.numeric_to_label_dict[row["label"]]),
+                                    row["precision"])
+                            except MisconfigurationException:
+                                pass
+                    elif key in ["predictions", "results"]:
+                        # Don't log dataframes of predictions or IoU results per epoch
                         pass
+                    else:
+                        try:
+                            self.log(key, value)
+                        except MisconfigurationException:
+                            pass

     def predict_step(self, batch, batch_idx):
         batch_results = self.model(batch)
@@ -863,12 +950,13 @@ def evaluate(self, csv_file, root_dir, iou_threshold=None, savedir=None):
         iou_threshold.

         Args:
-            csv_file: location of a csv file with columns "name","xmin","ymin","xmax","ymax","label", each box in a row
-            root_dir: location of files in the dataframe 'name' column.
-            iou_threshold: float [0,1] intersection-over-union union between annotation and prediction to be scored true positive
+            csv_file: location of a csv file with columns "name","xmin","ymin","xmax","ymax","label"
+            root_dir: location of files in the dataframe 'name' column
+            iou_threshold: float [0,1] intersection-over-union threshold for true positive
             savedir: location to save images with bounding boxes
+
         Returns:
-            results: dict of ("results", "precision", "recall") for a given threshold
+            dict: Results dictionary containing precision, recall and other metrics
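+
+        Example:
+            >>> # A minimal sketch; the csv path and root_dir are hypothetical.
+            >>> results = m.evaluate(csv_file="annotations.csv",
+            ...                      root_dir="images/",
+            ...                      iou_threshold=0.4)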
         """
         ground_df = utilities.read_file(csv_file)
         ground_df["label"] = ground_df.label.apply(lambda x: self.label_dict[x])
diff --git a/tests/test_main.py b/tests/test_main.py
index 13d10826..c234aa83 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -15,6 +15,8 @@
 from albumentations.pytorch import ToTensorV2
 from deepforest import main, get_data, dataset, model
+from deepforest.visualize import format_geometry
+
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.loggers import TensorBoardLogger
@@ -165,6 +167,28 @@ def test_train_empty(m, tmpdir):

 def test_validation_step(m):
+    val_dataloader = m.val_dataloader()
+    batch = next(iter(val_dataloader))
+    m.predictions = []
+    val_loss = m.validation_step(batch, 0)
+    assert val_loss != 0
+
+def test_validation_step_empty():
+    """If the model returns an empty prediction, the metrics should not fail"""
+    m = main.deepforest()
+    m.config["validation"]["csv_file"] = get_data("example.csv")
+    m.config["validation"]["root_dir"] = os.path.dirname(get_data("example.csv"))
+    m.create_trainer()
+
+    val_dataloader = m.val_dataloader()
+    batch = next(iter(val_dataloader))
+    m.predictions = []
+    val_loss = m.validation_step(batch, 0)
+    assert len(m.predictions) == 1
+    assert m.predictions[0].xmin.isna().all()
+    assert m.iou_metric.compute()["iou"] == 0
+
+def test_validate(m):
     m.trainer = None
     # Turn off trainer to test copying on some linux devices.
     before = copy.deepcopy(m)
@@ -677,7 +701,6 @@ def test_predict_tile_with_crop_model_empty():
     # Assert the result
     assert result is None
-
 # @pytest.mark.parametrize("batch_size", [1, 4, 8])
 # def test_batch_prediction(m, batch_size, raster_path):
 #
@@ -779,3 +802,108 @@
 #         "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
 #     }
 #     assert not batch_pred.empty
+
+def test_epoch_evaluation_end(m):
+    preds = [{
+        'boxes': torch.tensor([
+            [690.3572, 902.9113, 781.1031, 996.5151],
+            [998.1990, 655.7919, 172.4619, 321.8518]
+        ]),
+        'scores': torch.tensor([
+            0.6740, 0.6625
+        ]),
+        'labels': torch.tensor([
+            0, 0
+        ])
+    }]
+    targets = preds
+
+    m.iou_metric.update(preds, targets)
+    m.mAP_metric.update(preds, targets)
+
+    boxes = format_geometry(preds[0])
+    boxes["image_path"] = "test"
+    m.predictions = [boxes]
+    m.on_validation_epoch_end()
+
+def test_epoch_evaluation_end_empty(m):
+    """If the model returns an empty prediction, the metrics should not fail"""
+    preds = [{
+        'boxes': torch.zeros((1, 4)),
+        'scores': torch.zeros(1),
+        'labels': torch.zeros(1, dtype=torch.int64)
+    }]
+    targets = preds
+
+    m.iou_metric.update(preds, targets)
+    m.mAP_metric.update(preds, targets)
+
+    boxes = format_geometry(preds[0])
+    boxes["image_path"] = "test"
+    m.predictions = [boxes]
+    m.on_validation_epoch_end()
+
+def test_empty_frame_accuracy_with_predictions(m, tmpdir):
+    """Create a ground truth with only empty frames; a model that makes predictions on them should score 0"""
+    # Create ground truth with empty frames
+    ground_df = pd.read_csv(get_data("testfile_deepforest.csv"))
+    # Set all xmin, ymin, xmax, ymax to 0
+    ground_df.loc[:, ["xmin", "ymin", "xmax", "ymax"]] = 0
+    ground_df.drop_duplicates(subset=["image_path"], keep="first", inplace=True)
+
+    # Save the ground truth to a temporary file
+    ground_df.to_csv(tmpdir.strpath + "/ground_truth.csv", index=False)
+    m.config["validation"]["csv_file"] = tmpdir.strpath + "/ground_truth.csv"
+    m.config["validation"]["root_dir"] = os.path.dirname(get_data("testfile_deepforest.csv"))
+
+    m.create_trainer()
+    results = m.trainer.validate(m)
+    assert results[0]["empty_frame_accuracy"] == 0
+
+def test_empty_frame_accuracy_without_predictions(tmpdir):
+    """Create a ground truth with empty frames; the accuracy should be 1 with a randomly initialized model"""
+    m = main.deepforest()
+    # Create ground truth with empty frames
+    ground_df = pd.read_csv(get_data("testfile_deepforest.csv"))
+    # Set all xmin, ymin, xmax, ymax to 0
+    ground_df.loc[:, ["xmin", "ymin", "xmax", "ymax"]] = 0
+    ground_df.drop_duplicates(subset=["image_path"], keep="first", inplace=True)
+
+    # Save the ground truth to a temporary file
+    ground_df.to_csv(tmpdir.strpath + "/ground_truth.csv", index=False)
+    m.config["validation"]["csv_file"] = tmpdir.strpath + "/ground_truth.csv"
+    m.config["validation"]["root_dir"] = os.path.dirname(get_data("testfile_deepforest.csv"))
+
+    m.create_trainer()
+    results = m.trainer.validate(m)
+    assert results[0]["empty_frame_accuracy"] == 1
+
+def test_multi_class_with_empty_frame_accuracy_without_predictions(two_class_m, tmpdir):
+    """Create a multi-class ground truth with empty frames; the accuracy should be 1 with a randomly initialized model"""
+    # Create ground truth with empty frames
+    ground_df = pd.read_csv(get_data("testfile_deepforest.csv"))
+    # Set all xmin, ymin, xmax, ymax to 0
+    ground_df.loc[:, ["xmin", "ymin", "xmax", "ymax"]] = 0
+    ground_df.drop_duplicates(subset=["image_path"], keep="first", inplace=True)
+    ground_df.loc[:, "label"] = "Alive"
+
+    # Merge with a multi class ground truth
+    multi_class_df = pd.read_csv(get_data("testfile_multi.csv"))
+    ground_df = pd.concat([ground_df, multi_class_df])
+
+    # Save the ground truth to a temporary file
+    ground_df.to_csv(tmpdir.strpath + "/ground_truth.csv", index=False)
+    two_class_m.config["validation"]["csv_file"] = tmpdir.strpath + "/ground_truth.csv"
+    two_class_m.config["validation"]["root_dir"] = os.path.dirname(get_data("testfile_deepforest.csv"))
+
+    two_class_m.create_trainer()
+    results = two_class_m.trainer.validate(two_class_m)
+    assert results[0]["empty_frame_accuracy"] == 1
+
+def test_evaluate_on_epoch_interval(m):
+    m.config["validation"]["val_accuracy_interval"] = 1
+    m.config["train"]["epochs"] = 1
+    m.create_trainer()
+    m.trainer.fit(m)
+    assert m.trainer.logged_metrics["box_precision"]
+    assert m.trainer.logged_metrics["box_recall"]
\ No newline at end of file