diff --git a/documentation/source/ModelPredictions.md b/documentation/source/ModelPredictions.md
index 784b5c2ba7..1cb36b2bd4 100644
--- a/documentation/source/ModelPredictions.md
+++ b/documentation/source/ModelPredictions.md
@@ -1,4 +1,4 @@
-# Using Pretrained Models for Predictions 
+# Using Pretrained Models for Predictions

In this tutorial, we will demonstrate how to use the `model.predict()` method for object detection tasks.

@@ -10,7 +10,7 @@ The model used in this tutorial is [YOLO-NAS](YoloNASQuickstart.md), pre-trained

## Supported Media Formats

-A `mode.predict()` method is built to handle multiple data formats and types. 
+The `model.predict()` method is built to handle multiple data formats and types.

Here is the full list of what the `predict()` method can handle:

| Argument Semantics | Argument Type | Supported layout | Example | Notes |
@@ -25,13 +25,13 @@ Here is the full list of what `predict()` method can handle:
| 3-dimensional Torch Tensor | `torch.Tensor` | `[H, W, C]` or `[C, H, W]` | `predict(torch.zeros((480, 640, 3), dtype=torch.uint8))` | Tensor layout (HWC or CHW) is inferred from the number of input channels of the underlying model |
| 4-dimensional Torch Tensor | `torch.Tensor` | `[N, H, W, C]` or `[N, C, H, W]` | `predict(torch.zeros((4, 480, 640, 3), dtype=torch.uint8))` | Tensor layout (NHWC or NCHW) is inferred from the number of input channels of the underlying model |

-**Important note** - When using batched input (4-dimensional `np.ndarray` or `torch.Tensor`) formats, **normalization and size preprocessing will be applied to these inputs**. 
+**Important note** - When using batched input (4-dimensional `np.ndarray` or `torch.Tensor`) formats, **normalization and size preprocessing will be applied to these inputs**.
This means that the input tensors **should not** be normalized beforehand.
Here is an example of **incorrect** usage of `model.predict()`:

```python
# Incorrect code example. Do not use it.
-from super_gradients.training import dataloaders 
+from super_gradients.training import dataloaders
from super_gradients.common.object_names import Models
from super_gradients.training import models

@@ -139,6 +139,10 @@ You can also directly access a specific image prediction by referencing its index

## Detect Objects in Animated GIFs and Videos

The processing for both gifs and videos is similar, as they are treated as videos internally. You can use the same `model.predict()` method as before, but pass the path to a GIF or video file instead. The results can be saved as either a `.gif` or `.mp4`.
+To mitigate Out-of-Memory (OOM) errors, the `model.predict()` method for video returns a generator object. This allows the video frames to be processed sequentially, minimizing memory usage. Be aware that model inference in this mode is slower, since batching is not supported.
+
+Because the generator is exhausted after a single pass, you need to invoke `model.predict()` again before each `show()` and `save()` call.
+
### Load an Animated GIF or Video File

Let's load an animated GIF or a video file and pass it to the `model.predict()` method:

@@ -170,7 +174,7 @@ media_predictions.save("output_video.mp4") # Save as .mp4

The number of Frames Per Second (FPS) at which the model processes the gif/video can be seen directly next to the loading bar when running `model.predict('my_video.mp4')`. In the following example, the rate is 39.49 it/s, i.e. 39.49 FPS:

-`Predicting Video: 100%|███████████████████████| 306/306 [00:07<00:00, 39.49it/s]`
+`Processing Video: 100%|███████████████████████| 306/306 [00:07<00:00, 39.49it/s]`

Note that the video/gif will be saved with the original FPS (i.e. `media_predictions.fps`).
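Putting the video workflow together, here is a minimal sketch of the one-pass generator behavior described above (the model name and file paths are illustrative):

```python
from super_gradients.common.object_names import Models
from super_gradients.training import models

model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")

# Each predict() call returns a fresh generator over the video frames.
media_predictions = model.predict("my_video.mp4")
media_predictions.save("output_video.mp4")

# The previous generator is now exhausted, so predict() is invoked again.
media_predictions = model.predict("my_video.mp4")
media_predictions.show()
```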
@@ -237,13 +241,13 @@ predictions = model.predict(image, skip_image_resizing=True)

The following images illustrate the difference in detection results with and without resizing.

#### Original Image
-![Original Image](images/detection_example_beach_raw_image.jpeg)
+![Original Image](images/detection_example_beach_raw_image.jpeg)
*This is the raw image before any processing.*

#### Image Processed with Standard Resizing (640x640)
-![Resized Image](images/detection_example_beach_resized_predictions.jpg)
+![Resized Image](images/detection_example_beach_resized_predictions.jpg)
*This image shows the detection results after resizing the image to the model's trained size of 640x640.*

#### Image Processed in Original Size
-![Original Size Image](images/detection_example_beach_raw_image_prediction.jpg)
+![Original Size Image](images/detection_example_beach_raw_image_prediction.jpg)
*Here, the image is processed in its original size, demonstrating how the model performs without resizing. Notice the differences in object detection and details compared to the resized version.*
diff --git a/src/super_gradients/examples/predict/detection_predict_video.py b/src/super_gradients/examples/predict/detection_predict_video.py
index acc935aa26..0d79b7c4ac 100644
--- a/src/super_gradients/examples/predict/detection_predict_video.py
+++ b/src/super_gradients/examples/predict/detection_predict_video.py
@@ -17,6 +17,10 @@
    f.write(response.content)

predictions = model.predict(video_path)
-predictions.show()
predictions.save("pose_elephant_flip_prediction.mp4")
+
+predictions = model.predict(video_path)
predictions.save("pose_elephant_flip_prediction.gif")  # Can also be saved as a gif.
+
+predictions = model.predict(video_path)
+predictions.show()
diff --git a/src/super_gradients/examples/predict/pose_estimation_predict_video.py b/src/super_gradients/examples/predict/pose_estimation_predict_video.py
new file mode 100644
index 0000000000..a30fab8379
--- /dev/null
+++ b/src/super_gradients/examples/predict/pose_estimation_predict_video.py
@@ -0,0 +1,25 @@
+import argparse
+
+import torch
+from super_gradients.training import models
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-p", "--path_to_video", type=str)
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+
+    # Note that currently only YoloX, PPYoloE and YOLO-NAS are supported.
+    model = models.get("yolo_nas_pose_l", pretrained_weights="coco_pose")
+
+    # We want to use cuda if available to speed up inference.
+    model = model.to("cuda" if torch.cuda.is_available() else "cpu")
+
+    predictions = model.predict(args.path_to_video)
+    predictions.save(f"{args.path_to_video.split('/')[-1]}_prediction.mp4")
+
+    predictions = model.predict(args.path_to_video)
+    predictions.save(f"{args.path_to_video.split('/')[-1]}_prediction.gif")  # Can also be saved as a gif.
+
+    predictions = model.predict(args.path_to_video)
+    predictions.show()
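Assuming it is run from the repository root, the new pose example would be invoked along the lines of `python src/super_gradients/examples/predict/pose_estimation_predict_video.py -p my_video.mp4`, writing both an `.mp4` and a `.gif` prediction into the current working directory.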
diff --git a/src/super_gradients/training/pipelines/pipelines.py b/src/super_gradients/training/pipelines/pipelines.py
index 9653bb2acd..6bcbb8e99e 100644
--- a/src/super_gradients/training/pipelines/pipelines.py
+++ b/src/super_gradients/training/pipelines/pipelines.py
@@ -25,7 +25,7 @@
    ClassificationPrediction,
)
from super_gradients.training.utils.utils import generate_batch, infer_model_device, resolve_torch_device
-from super_gradients.training.utils.media.video import load_video, includes_video_extension
+from super_gradients.training.utils.media.video import includes_video_extension, lazy_load_video
from super_gradients.training.utils.media.image import ImageSource, check_image_typing
from super_gradients.training.utils.media.stream import WebcamStreaming
from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
@@ -134,9 +134,9 @@ def predict_video(self, video_path: str, batch_size: Optional[int] = 32) -> Vide
        :param batch_size:  The size of each batch.
        :return:            Results of the prediction.
        """
-        video_frames, fps = load_video(file_path=video_path)
+        video_frames, fps, num_frames = lazy_load_video(file_path=video_path)
        result_generator = self._generate_prediction_result(images=video_frames, batch_size=batch_size)
-        return self._combine_image_prediction_to_video(result_generator, fps=fps, n_images=len(video_frames))
+        return self._combine_image_prediction_to_video(result_generator, fps=fps, n_images=num_frames)

    def predict_webcam(self) -> None:
        """Predict using webcam"""
@@ -335,8 +336,7 @@ def _combine_image_prediction_to_images(
    def _combine_image_prediction_to_video(
        self, images_predictions: Iterable[ImageDetectionPrediction], fps: float, n_images: Optional[int] = None
    ) -> VideoDetectionPrediction:
-        images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Video")]
-        return VideoDetectionPrediction(_images_prediction_lst=images_predictions, fps=fps)
+        return VideoDetectionPrediction(_images_prediction_gen=images_predictions, fps=fps, n_frames=n_images)


class PoseEstimationPipeline(Pipeline):
@@ -419,8 +419,7 @@ def _combine_image_prediction_to_images(
    def _combine_image_prediction_to_video(
        self, images_predictions: Iterable[ImageDetectionPrediction], fps: float, n_images: Optional[int] = None
    ) -> VideoPoseEstimationPrediction:
-        images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Video")]
-        return VideoPoseEstimationPrediction(_images_prediction_lst=images_predictions, fps=fps)
+        return VideoPoseEstimationPrediction(_images_prediction_gen=images_predictions, fps=fps, n_frames=n_images)


class ClassificationPipeline(Pipeline):
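The pipeline change above swaps an eager list of decoded frames for a generator that is pulled one frame at a time. For intuition, a minimal self-contained sketch of that streaming pattern (the names here are hypothetical stand-ins, not SuperGradients APIs):

```python
from typing import Iterator

import numpy as np


def frame_source(n_frames: int) -> Iterator[np.ndarray]:
    # Stand-in for lazy_load_video: one frame is decoded per next() call.
    for _ in range(n_frames):
        yield np.zeros((480, 640, 3), dtype=np.uint8)


def run_inference(frames: Iterator[np.ndarray]) -> Iterator[str]:
    # Stand-in for the prediction pipeline: consumes frames lazily.
    for frame in frames:
        yield f"prediction for frame of shape {frame.shape}"


# The whole video is never resident in memory: each frame is decoded,
# processed, and discarded before the next one is read.
for prediction in run_inference(frame_source(5)):
    print(prediction)
```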
diff --git a/src/super_gradients/training/utils/media/video.py b/src/super_gradients/training/utils/media/video.py
index 1808ff11ff..cfaef64fc9 100644
--- a/src/super_gradients/training/utils/media/video.py
+++ b/src/super_gradients/training/utils/media/video.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Iterable, Iterator

import cv2
import PIL
@@ -30,6 +30,25 @@ def load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[List[n
    return frames, fps


+def lazy_load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[Iterator[np.ndarray], int, int]:
+    """Open a video file and return a generator that yields its frames.
+
+    :param file_path:  Path to the video file.
+    :param max_frames: Optional, maximum number of frames to extract.
+    :return:
+                - Generator yielding frames representing the video, each in (H, W, C), RGB.
+                - Frames per Second (FPS).
+                - Number of frames in the video.
+    """
+    cap = _open_video(file_path)
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    frames = _lazy_extract_frames(cap, max_frames)
+    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    if max_frames is not None:
+        # Respect the cap, otherwise progress bars relying on this count would overshoot.
+        num_frames = min(num_frames, max_frames)
+    return frames, fps, num_frames
+
+
def _open_video(file_path: str) -> cv2.VideoCapture:
    """Open a video file.
@@ -61,6 +80,27 @@ def _extract_frames(cap: cv2.VideoCapture, max_frames: Optional[int] = None) ->
    return frames


+def _lazy_extract_frames(cap: cv2.VideoCapture, max_frames: Optional[int] = None) -> Iterator[np.ndarray]:
+    """Lazy implementation of frame extraction from an opened video capture object.
+    NOTE: Releases the capture object once the generator is fully consumed.
+
+    :param cap:        Opened video capture object.
+    :param max_frames: Optional maximum number of frames to extract.
+    :return:           Generator yielding frames representing the video, each in (H, W, C), RGB.
+    """
+    frames_counter = 0
+
+    while frames_counter != max_frames:
+        frame_read_success, frame = cap.read()
+        if not frame_read_success:
+            break
+
+        frames_counter += 1
+        yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+    cap.release()
+
+
def save_video(output_path: str, frames: List[np.ndarray], fps: int) -> None:
    """Save a video locally. Depending on the extension, the video will be saved as a .mp4 file or as a .gif file.
@@ -78,64 +118,64 @@ def save_video(output_path: str, frames: List[np.ndarray], fps: int) -> None:
        save_mp4(output_path, frames, fps)


-def save_gif(output_path: str, frames: List[np.ndarray], fps: int) -> None:
-    """Save a video locally in .gif format.
+def save_gif(output_path: str, frames: Iterable[np.ndarray], fps: int) -> None:
+    """Save a video locally in .gif format. Safe to call with a generator of frames.

    :param output_path: Where the video will be saved
    :param frames:      Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
    :param fps:         Frames per second
    """
+    frame_iter_obj = iter(frames)
+    pil_frames_iter_obj = map(PIL.Image.fromarray, frame_iter_obj)

-    frames_pil = [PIL.Image.fromarray(frame) for frame in frames]
+    first_frame = next(pil_frames_iter_obj)

-    frames_pil[0].save(output_path, save_all=True, append_images=frames_pil[1:], duration=int(1000 / fps), loop=0)
+    first_frame.save(output_path, save_all=True, append_images=pil_frames_iter_obj, duration=int(1000 / fps), loop=0)
-def save_mp4(output_path: str, frames: List[np.ndarray], fps: int) -> None:
-    """Save a video locally in .mp4 format.
+def save_mp4(output_path: str, frames: Iterable[np.ndarray], fps: int) -> None:
+    """Save a video locally in .mp4 format. Safe to call with a generator of frames.

    :param output_path: Where the video will be saved
    :param frames:      Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
    :param fps:         Frames per second
    """
-    video_height, video_width = _validate_frames(frames)
-
-    video_writer = cv2.VideoWriter(
-        output_path,
-        cv2.VideoWriter_fourcc(*"mp4v"),
-        fps,
-        (video_width, video_height),
-    )
+    video_height, video_width, video_writer = None, None, None

    for frame in frames:
+        if video_height is None:
+            video_height, video_width = frame.shape[:2]
+            video_writer = cv2.VideoWriter(
+                output_path,
+                cv2.VideoWriter_fourcc(*"mp4v"),
+                fps,
+                (video_width, video_height),
+            )
+        _validate_frame(frame, video_height, video_width)
        video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

-    video_writer.release()
+    if video_writer is not None:  # video_writer stays None when the frame iterable is empty.
+        video_writer.release()


-def _validate_frames(frames: List[np.ndarray]) -> Tuple[float, float]:
-    """Validate the frames to make sure that every frame has the same size and includes the channel dimension. (i.e. (H, W, C))
+def _validate_frame(frame: np.ndarray, control_height: int, control_width: int) -> None:
+    """Validate the frame to make sure it has the correct size and includes the channel dimension. (i.e. (H, W, C))

-    :param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
-    :return: (Height, Weight) of the video.
+    :param frame:          Single frame from the video, in (H, W, C), RGB.
+    :param control_height: Expected frame height.
+    :param control_width:  Expected frame width.
    """
-    min_height = min(frame.shape[0] for frame in frames)
-    max_height = max(frame.shape[0] for frame in frames)
-
-    min_width = min(frame.shape[1] for frame in frames)
-    max_width = max(frame.shape[1] for frame in frames)
+    height, width = frame.shape[:2]

-    if (min_height, min_width) != (max_height, max_width):
+    if (height, width) != (control_height, control_width):
        raise RuntimeError(
-            f"Your video is made of frames that have (height, width) going from ({min_height}, {min_width}) to ({max_height}, {max_width}).\n"
+            f"Current frame has resolution {height}x{width} but {control_height}x{control_width} is expected!\n"
            f"Please make sure that all the frames have the same shape."
        )

-    if set(frame.ndim for frame in frames) != {3} or set(frame.shape[-1] for frame in frames) != {3}:
+    if frame.ndim != 3 or frame.shape[-1] != 3:
        raise RuntimeError("Your frames must include 3 channels.")

-    return max_height, max_width
-

def show_video_from_disk(video_path: str, window_name: str = "Prediction"):
    """Display a video from disk using OpenCV.
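The core pattern in the new `save_mp4` is deferred writer construction: the `cv2.VideoWriter` is created only when the first frame arrives, so the resolution never needs to be known up front and a generator can be consumed in a single pass. A condensed, standalone sketch of the same pattern (paths and sizes are illustrative):

```python
import cv2
import numpy as np


def write_stream(output_path: str, frames, fps: int = 30) -> None:
    writer = None
    for frame in frames:
        if writer is None:
            # The first frame fixes the resolution and creates the writer.
            height, width = frame.shape[:2]
            writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
        writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    if writer is not None:
        writer.release()


# Works with a generator: frames are produced and written one at a time.
write_stream("demo.mp4", (np.zeros((240, 320, 3), dtype=np.uint8) for _ in range(30)))
```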
diff --git a/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py b/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
index 25d5bb1589..7643662081 100644
--- a/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
+++ b/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
@@ -1,6 +1,6 @@
import os
from dataclasses import dataclass
-from typing import List
+from typing import List, Iterator

import numpy as np

@@ -9,6 +9,8 @@
from super_gradients.training.utils.media.video import show_video_from_frames, save_video
from super_gradients.training.utils.visualization.pose_estimation import PoseVisualization

+from tqdm import tqdm
+

@dataclass
class ImagePoseEstimationPrediction(ImagePrediction):
@@ -210,8 +212,9 @@ class VideoPoseEstimationPrediction(VideoPredictions):
    :att fps: Frames per second of the video
    """

-    _images_prediction_lst: List[ImagePoseEstimationPrediction]
+    _images_prediction_gen: Iterator[ImagePoseEstimationPrediction]
    fps: int
+    n_frames: int

    def draw(
        self,
@@ -221,7 +224,7 @@
        keypoint_radius: int = 5,
        box_thickness: int = 2,
        show_confidence: bool = False,
-    ) -> List[np.ndarray]:
+    ) -> Iterator[np.ndarray]:
        """Draw the predicted bboxes on the images.

        :param output_folder: Folder path, where the images will be saved.
@@ -236,10 +239,11 @@
        :param show_confidence: Whether to show confidence scores on the image.
        :param box_thickness:   Thickness of bounding boxes.

-        :return: List of images with predicted bboxes. Note that this does not modify the original image.
+        :return: Iterator of images with predicted bboxes. Note that this does not modify the original image.
        """
-        frames_with_bbox = [
-            result.draw(
+
+        for result in tqdm(self._images_prediction_gen, total=self.n_frames, desc="Processing Video"):
+            yield result.draw(
                edge_colors=edge_colors,
                joint_thickness=joint_thickness,
                keypoint_colors=keypoint_colors,
@@ -247,9 +251,6 @@
                box_thickness=box_thickness,
                show_confidence=show_confidence,
            )
-            for result in self._images_prediction_lst
-        ]
-        return frames_with_bbox

    def show(
        self,
diff --git a/src/super_gradients/training/utils/predict/prediction_results.py b/src/super_gradients/training/utils/predict/prediction_results.py
index a4a987756d..0249365c56 100644
--- a/src/super_gradients/training/utils/predict/prediction_results.py
+++ b/src/super_gradients/training/utils/predict/prediction_results.py
@@ -16,6 +16,8 @@
from .predictions import Prediction, DetectionPrediction, ClassificationPrediction
from ...datasets.data_formats.bbox_formats import convert_bboxes

+from tqdm import tqdm
+

@dataclass
class ImagePrediction(ABC):
@@ -325,15 +327,16 @@ def save(self, *args, **kwargs) -> None:

@dataclass
-class VideoPredictions(ImagesPredictions, ABC):
+class VideoPredictions(ABC):
    """Object wrapping the list of image predictions as a Video.

-    :attr _images_prediction_lst: List of results of the run
+    :attr _images_prediction_gen: Iterator of results of the run
    :att fps: Frames per second of the video
    """

-    _images_prediction_lst: List[ImagePrediction]
+    _images_prediction_gen: Iterator[ImagePrediction]
    fps: float
+    n_frames: int

    @abstractmethod
    def show(self, *args, **kwargs) -> None:
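Both `draw()` implementations above rely on `tqdm` being able to wrap a lazy iterable and still render progress when handed an explicit `total`. A tiny self-contained illustration of that pattern:

```python
from tqdm import tqdm


def frames(n: int):
    # Stand-in for the lazy prediction generator, which has no len().
    for i in range(n):
        yield i


# Passing total= lets tqdm show a percentage even for a plain generator.
for _ in tqdm(frames(306), total=306, desc="Processing Video"):
    pass
```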
@@ -504,12 +507,13 @@ class VideoDetectionPrediction(VideoPredictions):
    """Object wrapping the list of image detection predictions as a Video.

-    :attr _images_prediction_lst: List of the predictions results
+    :attr _images_prediction_gen: Iterator of the prediction results
    :att fps: Frames per second of the video
    """

-    _images_prediction_lst: List[ImageDetectionPrediction]
+    _images_prediction_gen: Iterator[ImageDetectionPrediction]
    fps: int
+    n_frames: int

    def draw(
        self,
@@ -517,7 +521,7 @@
        show_confidence: bool = True,
        color_mapping: Optional[List[Tuple[int, int, int]]] = None,
        class_names: Optional[List[str]] = None,
-    ) -> List[np.ndarray]:
+    ) -> Iterator[np.ndarray]:
        """Draw the predicted bboxes on the images.

        :param box_thickness: Thickness of bounding boxes.
@@ -525,18 +529,16 @@
        :param show_confidence: Whether to show confidence scores on the image.
        :param color_mapping: List of tuples representing the colors for each class. Default is None, which generates a default color mapping based on the number of class names.
        :param class_names: List of class names to show. By default, is None, which shows all classes used during training.
-        :return: List of images with predicted bboxes. Note that this does not modify the original image.
+        :return: Iterator of images with predicted bboxes. Note that this does not modify the original image.
        """
-        frames_with_bbox = [
-            result.draw(
+
+        for result in tqdm(self._images_prediction_gen, total=self.n_frames, desc="Processing Video"):
+            yield result.draw(
                box_thickness=box_thickness,
                show_confidence=show_confidence,
                color_mapping=color_mapping,
                class_names=class_names,
            )
-            for result in self._images_prediction_lst
-        ]
-        return frames_with_bbox

    def show(
        self,
diff --git a/tests/unit_tests/test_predict.py b/tests/unit_tests/test_predict.py
index 886b3f1a84..a7638d57d1 100644
--- a/tests/unit_tests/test_predict.py
+++ b/tests/unit_tests/test_predict.py
@@ -3,12 +3,13 @@
import tempfile
from pathlib import Path

-import numpy as np
-
from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.datasets import COCODetectionDataset

+import cv2
+import numpy as np
+

class TestModelPredict(unittest.TestCase):
    def setUp(self) -> None:
@@ -36,6 +37,22 @@ def _set_images_with_targets(self):
        self.np_array_target_bboxes = [y1[:, :4], y2[:, :4]]
        self.np_array_target_class_ids = [y1[:, 4], y2[:, 4]]

+    def _prepare_video(self, path):
+        video_width, video_height = 400, 400
+        fps = 10
+        num_frames = 20
+        video_writer = cv2.VideoWriter(
+            path,
+            cv2.VideoWriter_fourcc(*"mp4v"),
+            fps,
+            (video_width, video_height),
+        )
+
+        frames = np.zeros((num_frames, video_height, video_width, 3), dtype=np.uint8)
+        for frame in frames:
+            video_writer.write(frame)
+        video_writer.release()
+
    def test_classification_models(self):
        with tempfile.TemporaryDirectory() as tmp_dirname:
            for model_name in {Models.RESNET18, Models.EFFICIENTNET_B0, Models.MOBILENET_V2, Models.REGNETY200}:
@@ -86,6 +103,23 @@ def test_predict_class_names(self):
            with self.assertRaises(ValueError):
                _ = predictions.show(class_names=["human"])

+    def test_predict_video(self):
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            video_path = os.path.join(tmp_dirname, "test.mp4")
+            self._prepare_video(video_path)
+            for model_name in [Models.YOLO_NAS_S, Models.YOLOX_S, Models.YOLO_NAS_POSE_S]:
+                pretrained_weights = "coco"
+                if model_name == Models.YOLO_NAS_POSE_S:
+                    pretrained_weights += "_pose"
+                model = models.get(model_name, pretrained_weights=pretrained_weights)
+
+                predictions = model.predict(video_path)
+                predictions.save(os.path.join(tmp_dirname, "test_predict_video_detection.mp4"))
+
+                # The prediction generator is single-pass, so predict() is called again for the gif.
+                predictions = model.predict(video_path)
+                predictions.save(os.path.join(tmp_dirname, "test_predict_video_detection.gif"))
+
    def test_predict_detection_skip_resize(self):
        for model_name in [Models.YOLO_NAS_S, Models.YOLOX_S, Models.PP_YOLOE_S]:
            model = models.get(model_name, pretrained_weights="coco")