Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] lazy implementation of video predictions #1621

Merged
merged 24 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
06cf002
Draft code for replacing list predictions with generator for video, mp4
philmarchenko Nov 7, 2023
5d6423c
Changed types, united save_mp4, rewrote save_gif
philmarchenko Nov 8, 2023
1835cf3
Merge branch 'master' into feature/video-save-show-fix
philmarchenko Nov 8, 2023
e040b47
Returned correct type inside iterables in Video(Detection)Prediction
philmarchenko Nov 8, 2023
4e5caac
Changed types to more correct ones
philmarchenko Nov 8, 2023
cade080
Provided similar changes to PoseEstimationVideoPrediction
philmarchenko Nov 8, 2023
0b6ec0e
Merge branch 'feature/video-save-show-fix' of github.com:hakuryuu96/s…
philmarchenko Nov 8, 2023
fe6091e
Merge branch 'master' into feature/video-save-show-fix
BloodAxe Nov 13, 2023
ccf2357
Added tests of save and show for video predictions
philmarchenko Nov 13, 2023
51c6d0f
Added example script for PE and changed detection example script
philmarchenko Nov 13, 2023
7f68cb7
Merge branch 'master' into feature/video-save-show-fix
philmarchenko Nov 13, 2023
4095b83
Removed test warning filter
philmarchenko Nov 13, 2023
d6a3390
Merge branch 'feature/video-save-show-fix' of github.com:hakuryuu96/s…
philmarchenko Nov 13, 2023
d11b121
Removed unused import
philmarchenko Nov 13, 2023
f3bce78
Removed duplicated import
philmarchenko Nov 13, 2023
fe7cb37
Removed show() from video tests, not available
philmarchenko Nov 13, 2023
ec705cd
Fixed pretrained weights flag in test video
philmarchenko Nov 13, 2023
1e25265
Merge branch 'master' into feature/video-save-show-fix
philmarchenko Nov 13, 2023
a6096e5
Merge branch 'master' into feature/video-save-show-fix
philmarchenko Nov 14, 2023
b1a3e1b
Merge branch 'master' into feature/video-save-show-fix
BloodAxe Nov 14, 2023
e81ec61
Replaced link with an arg for video path
philmarchenko Nov 15, 2023
a79a58f
Added a documentation line regarding the video predictions
philmarchenko Nov 15, 2023
6f2c467
Changed word in progress bar example
philmarchenko Nov 15, 2023
8f4cb05
Merge branch 'master' into feature/video-save-show-fix
philmarchenko Nov 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/super_gradients/training/pipelines/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
ClassificationPrediction,
)
from super_gradients.training.utils.utils import generate_batch, infer_model_device, resolve_torch_device
from super_gradients.training.utils.media.video import load_video, includes_video_extension
from super_gradients.training.utils.media.video import load_video, includes_video_extension, lazy_load_video
from super_gradients.training.utils.media.image import ImageSource, check_image_typing
from super_gradients.training.utils.media.stream import WebcamStreaming
from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
Expand Down Expand Up @@ -133,9 +133,10 @@ def predict_video(self, video_path: str, batch_size: Optional[int] = 32) -> Vide
:param batch_size: The size of each batch.
:return: Results of the prediction.
"""
video_frames, fps = load_video(file_path=video_path)
video_frames, fps, num_frames = lazy_load_video(file_path=video_path)
result_generator = self._generate_prediction_result(images=video_frames, batch_size=batch_size)
return self._combine_image_prediction_to_video(result_generator, fps=fps, n_images=len(video_frames))
return self._combine_image_prediction_to_video(result_generator, fps=fps, n_images=num_frames)
# return self._combine_image_prediction_to_video(result_generator, fps=fps, n_images=len(video_frames))

def predict_webcam(self) -> None:
"""Predict using webcam"""
Expand Down Expand Up @@ -317,8 +318,8 @@ def _combine_image_prediction_to_images(
def _combine_image_prediction_to_video(
self, images_predictions: Iterable[ImageDetectionPrediction], fps: float, n_images: Optional[int] = None
) -> VideoDetectionPrediction:
images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Video")]
return VideoDetectionPrediction(_images_prediction_lst=images_predictions, fps=fps)
# images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Video")]
return VideoDetectionPrediction(_images_prediction_gen=images_predictions, fps=fps, n_frames=n_images)


class PoseEstimationPipeline(Pipeline):
Expand Down
63 changes: 61 additions & 2 deletions src/super_gradients/training/utils/media/video.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Generator
import cv2
import PIL

Expand Down Expand Up @@ -30,6 +30,15 @@ def load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[List[n
return frames, fps


def lazy_load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[Generator, float, int]:
    """Open a video file and return a lazy frame generator together with the video metadata.

    No frame is decoded by this call; frames are read one at a time as the returned generator is consumed.

    :param file_path:   Path to the video file.
    :param max_frames:  Optional maximum number of frames to yield. If None, frames are yielded until a read fails.
    :return:
        - Generator yielding the frames of the video.
        - Frames per second, as reported by the capture (`cv2.CAP_PROP_FPS`).
        - Number of frames the generator is expected to yield: the count reported by the capture
          (`cv2.CAP_PROP_FRAME_COUNT`), capped by `max_frames` when it is provided.
    """
    cap = _open_video(file_path)
    fps = cap.get(cv2.CAP_PROP_FPS)

    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # The generator stops after `max_frames` frames, so the advertised total (used e.g. for progress bars)
    # must not exceed it. Negative values never match the generator's stop condition and are left alone.
    if max_frames is not None and 0 <= max_frames < num_frames:
        num_frames = max_frames

    # The capture must stay open here: the generator reads from it and is responsible for releasing it.
    frames = _lazy_extract_frames(cap, max_frames)
    return frames, fps, num_frames


def _open_video(file_path: str) -> cv2.VideoCapture:
"""Open a video file.

Expand Down Expand Up @@ -61,6 +70,20 @@ def _extract_frames(cap: cv2.VideoCapture, max_frames: Optional[int] = None) ->
return frames


def _lazy_extract_frames(cap: cv2.VideoCapture, max_frames: Optional[int] = None) -> Generator:
    """Lazily read frames from an opened video capture, yielding them one at a time.

    :param cap:         Opened video capture object. It is released when the generator finishes, is closed
                        early, or is interrupted by an exception.
    :param max_frames:  Optional maximum number of frames to yield. If None, frames are yielded until a read fails.
    :return:            Generator yielding each frame converted from BGR to RGB.
    """
    frames_counter = 0

    # try/finally guarantees the capture is released even when the consumer stops iterating early
    # (generator closed / garbage collected) or an exception propagates through the `yield`.
    try:
        while frames_counter != max_frames:
            frame_read_success, frame = cap.read()
            if not frame_read_success:
                break

            frames_counter += 1
            yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    finally:
        cap.release()


def save_video(output_path: str, frames: List[np.ndarray], fps: int) -> None:
"""Save a video locally. Depending on the extension, the video will be saved as a .mp4 file or as a .gif file.

Expand All @@ -75,7 +98,7 @@ def save_video(output_path: str, frames: List[np.ndarray], fps: int) -> None:
if check_is_gif(output_path):
save_gif(output_path, frames, fps)
else:
save_mp4(output_path, frames, fps)
lazy_save_mp4(output_path, frames, fps)
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved


def save_gif(output_path: str, frames: List[np.ndarray], fps: int) -> None:
Expand Down Expand Up @@ -113,6 +136,24 @@ def save_mp4(output_path: str, frames: List[np.ndarray], fps: int) -> None:
video_writer.release()


def lazy_save_mp4(output_path, frames, fps) -> None:
    """Save a video locally in .mp4 format, consuming `frames` one at a time.

    The video writer is created from the resolution of the first frame, so `frames` can be any iterable
    (including a generator) and never has to be held in memory as a whole.
    If `frames` is empty, nothing is written and no writer is created.

    :param output_path: Where the video will be saved.
    :param frames:      Iterable of frames, each in (H, W, C), RGB. All the frames are expected to have the same shape.
    :param fps:         Frames per second of the output video.
    :raises RuntimeError: Propagated from `_validate_frame` when a frame does not pass validation.
    """
    video_height, video_width, video_writer = None, None, None

    try:
        for frame in frames:
            if video_height is None:
                # The first frame defines the resolution every following frame is checked against.
                video_height, video_width = frame.shape[:2]

            # Validate before the writer is created so that an invalid first frame does not open an output file.
            _validate_frame(frame, video_height, video_width)

            if video_writer is None:
                video_writer = cv2.VideoWriter(
                    output_path,
                    cv2.VideoWriter_fourcc(*"mp4v"),
                    fps,
                    (video_width, video_height),
                )
            video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # The writer does not exist when `frames` was empty or the first frame was rejected.
        if video_writer is not None:
            video_writer.release()


def _validate_frames(frames: List[np.ndarray]) -> Tuple[float, float]:
"""Validate the frames to make sure that every frame has the same size and includes the channel dimension. (i.e. (H, W, C))

Expand All @@ -137,6 +178,24 @@ def _validate_frames(frames: List[np.ndarray]) -> Tuple[float, float]:
return max_height, max_width


def _validate_frame(frame, control_height, control_width) -> None:
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved
"""Validate the frames to make sure that every frame has the same size and includes the channel dimension. (i.e. (H, W, C))

:param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
:return: (Height, Weight) of the video.
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved
"""
height, width = frame.shape[:2]

if (height, width) != (control_height, control_width):
raise RuntimeError(
f"Current frame has resolution {height}x{width} but {control_height}x{control_width} is expected!"
f"Please make sure that all the frames have the same shape."
)

if frame.ndim != 3:
raise RuntimeError("Your frames must include 3 channels.")


def show_video_from_disk(video_path: str, window_name: str = "Prediction"):
"""Display a video from disk using OpenCV.

Expand Down
26 changes: 14 additions & 12 deletions src/super_gradients/training/utils/predict/prediction_results.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Tuple, Iterator, Union
from typing import List, Optional, Tuple, Iterator, Union, Generator

import cv2
import numpy as np
Expand All @@ -16,6 +16,8 @@
from .predictions import Prediction, DetectionPrediction, ClassificationPrediction
from ...datasets.data_formats.bbox_formats import convert_bboxes

from tqdm import tqdm


@dataclass
class ImagePrediction(ABC):
Expand Down Expand Up @@ -325,15 +327,16 @@ def save(self, *args, **kwargs) -> None:


@dataclass
class VideoPredictions(ImagesPredictions, ABC):
class VideoPredictions(ABC):
"""Object wrapping the list of image predictions as a Video.

:attr _images_prediction_lst: List of results of the run
:attr _images_prediction_gen: Generator yielding the per-frame results of the run
:attr fps: Frames per second of the video
"""

_images_prediction_lst: List[ImagePrediction]
_images_prediction_gen: Generator
fps: float
n_frames: int

@abstractmethod
def show(self, *args, **kwargs) -> None:
Expand Down Expand Up @@ -504,20 +507,21 @@ def save(
class VideoDetectionPrediction(VideoPredictions):
"""Object wrapping the list of image detection predictions as a Video.

:attr _images_prediction_lst: List of the predictions results
:attr _images_prediction_gen: Generator yielding the per-frame prediction results
:attr fps: Frames per second of the video
"""

_images_prediction_lst: List[ImageDetectionPrediction]
_images_prediction_gen: Generator
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved
fps: int
n_frames: int

def draw(
self,
box_thickness: int = 2,
show_confidence: bool = True,
color_mapping: Optional[List[Tuple[int, int, int]]] = None,
class_names: Optional[List[str]] = None,
) -> List[np.ndarray]:
) -> Generator:
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved
"""Draw the predicted bboxes on the images.

:param box_thickness: Thickness of bounding boxes.
Expand All @@ -527,16 +531,14 @@ def draw(
:param class_names: List of class names to show. By default, is None which shows all classes using during training.
:return: Generator yielding the images with predicted bboxes. Note that this does not modify the original image.
"""
frames_with_bbox = [
result.draw(

for result in tqdm(self._images_prediction_gen, total=self.n_frames, desc="Processing Video"):
yield result.draw(
box_thickness=box_thickness,
show_confidence=show_confidence,
color_mapping=color_mapping,
class_names=class_names,
)
for result in self._images_prediction_lst
]
return frames_with_bbox

def show(
self,
Expand Down