From 34d0cdda64111aa92c9369ae5ca9731940ba2282 Mon Sep 17 00:00:00 2001 From: "Yi, Jihyeon" Date: Wed, 8 May 2024 17:29:39 +0900 Subject: [PATCH] add and update test codes regarding video dataset --- src/datumaro/components/media.py | 65 +++++++++-- .../data_formats/datumaro_binary/base.py | 3 +- tests/unit/data_formats/datumaro/conftest.py | 52 ++++++++- .../datumaro/test_datumaro_binary_format.py | 22 ++-- .../datumaro/test_datumaro_format.py | 21 +++- tests/unit/test_kinetics_format.py | 101 +++++++++++------- tests/unit/test_video.py | 16 ++- tests/utils/test_utils.py | 35 +++--- 8 files changed, 238 insertions(+), 77 deletions(-) diff --git a/src/datumaro/components/media.py b/src/datumaro/components/media.py index 1c99170fef..243e47d434 100644 --- a/src/datumaro/components/media.py +++ b/src/datumaro/components/media.py @@ -94,7 +94,7 @@ def media(self) -> Optional[Type[MediaElement]]: class MediaElement(Generic[AnyData]): _type = MediaType.MEDIA_ELEMENT - def __init__(self, crypter: Crypter = NULL_CRYPTER) -> None: + def __init__(self, crypter: Crypter = NULL_CRYPTER, *args, **kwargs) -> None: self._crypter = crypter def as_dict(self) -> Dict[str, Any]: @@ -488,6 +488,26 @@ def video(self) -> Video: def path(self) -> str: return self._video.path + def from_self(self, **kwargs): + attrs = deepcopy(self.as_dict()) + if "path" in kwargs: + attrs.update({"video": self.video.from_self(**kwargs)}) + kwargs.pop("path") + attrs.update(kwargs) + return self.__class__(**attrs) + + def __getstate__(self): + # Return only the picklable parts of the state. + state = self.__dict__.copy() + del state["_data"] + return state + + def __setstate__(self, state): + # Restore the objects' state. + self.__dict__.update(state) + # Reinitialize unpichlable attributes + self._data = lambda: self._video.get_frame_data(self._index) + class _VideoFrameIterator(Iterator[VideoFrame]): """ @@ -527,6 +547,11 @@ def _decode(self, cap) -> Iterator[VideoFrame]: if self._video._frame_count is None: self._video._frame_count = self._pos + 1 + if self._video._end_frame and self._video._end_frame >= self._video._frame_count: + raise ValueError( + f"The end_frame value({self._video._end_frame}) of the video " + f"must be less than the frame count({self._video._frame_count})." + ) def _make_frame(self, index) -> VideoFrame: return VideoFrame(self._video, index=index) @@ -575,14 +600,22 @@ class Video(MediaElement, Iterable[VideoFrame]): """ def __init__( - self, path: str, *, step: int = 1, start_frame: int = 0, end_frame: Optional[int] = None + self, + path: str, + step: int = 1, + start_frame: int = 0, + end_frame: Optional[int] = None, + *args, + **kwargs, ) -> None: - super().__init__() + super().__init__(*args, **kwargs) self._path = path assert 0 <= start_frame if end_frame: assert start_frame <= end_frame + # we can't know the video length here, + # so we cannot validate if the end_frame is valid. assert 0 < step self._step = step self._start_frame = start_frame @@ -727,12 +760,26 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, __class__): return False - return ( - self.path == other.path - and self._start_frame == other._start_frame - and self._step == other._step - and self._end_frame == other._end_frame - ) + if ( + self._start_frame != other._start_frame + or self._step != other._step + or self._end_frame != other._end_frame + ): + return False + + # The video path can vary if a dataset is copied. + # So, we need to check if the video data is the same instead of checking paths. + if self._end_frame is None: + # Decoding is not necessary to get frame pointers + # However, it can be inacurrate + end_frame = self._get_end_frame() + for index in range(self._start_frame, end_frame + 1, self._step): + yield VideoFrame(video=self, index=index) + for frame_self, frame_other in zip(self, other): + if frame_self != frame_other: + return False + + return True def __hash__(self): # Required for caching diff --git a/src/datumaro/plugins/data_formats/datumaro_binary/base.py b/src/datumaro/plugins/data_formats/datumaro_binary/base.py index 3616f22b93..efb00ca5f0 100644 --- a/src/datumaro/plugins/data_formats/datumaro_binary/base.py +++ b/src/datumaro/plugins/data_formats/datumaro_binary/base.py @@ -125,7 +125,8 @@ def _read_items(self) -> None: media_path_prefix = { MediaType.IMAGE: osp.join(self._images_dir, self._subset), MediaType.POINT_CLOUD: osp.join(self._pcd_dir, self._subset), - MediaType.VIDEO_FRAME: self._video_dir, + MediaType.VIDEO: osp.join(self._video_dir, self._subset), + MediaType.VIDEO_FRAME: osp.join(self._video_dir, self._subset), } if self._num_workers > 0: diff --git a/tests/unit/data_formats/datumaro/conftest.py b/tests/unit/data_formats/datumaro/conftest.py index 6707c2c637..a6f750fd97 100644 --- a/tests/unit/data_formats/datumaro/conftest.py +++ b/tests/unit/data_formats/datumaro/conftest.py @@ -28,11 +28,13 @@ RleMask, ) from datumaro.components.dataset_base import DatasetItem -from datumaro.components.media import Image, PointCloud +from datumaro.components.media import Image, MediaElement, PointCloud, Video, VideoFrame from datumaro.components.project import Dataset from datumaro.plugins.data_formats.datumaro.format import DatumaroPath from datumaro.util.mask_tools import generate_colormap +from tests.utils.video import make_sample_video + @pytest.fixture def fxt_test_datumaro_format_dataset(): @@ -199,6 +201,54 @@ def fxt_test_datumaro_format_dataset(): ) +@pytest.fixture +def fxt_test_datumaro_format_video_dataset(test_dir) -> Dataset: + video_path = osp.join(test_dir, "video.avi") + make_sample_video(video_path, frame_size=(4, 6), frames=4) + video = Video(video_path) + + return Dataset.from_iterable( + iterable=[ + DatasetItem( + "f0", + subset="train", + media=VideoFrame(video, 0), + annotations=[ + Bbox(1, 1, 1, 1, label=0, object_id=0), + Bbox(2, 2, 2, 2, label=1, object_id=1), + ], + ), + DatasetItem( + "f1", + subset="test", + media=VideoFrame(video, 0), + annotations=[ + Bbox(0, 0, 2, 2, label=1, object_id=1), + Bbox(3, 3, 1, 1, label=0, object_id=0), + ], + ), + DatasetItem( + "v0", + subset="train", + media=Video(video_path, step=1, start_frame=0, end_frame=1), + annotations=[ + Label(0), + ], + ), + DatasetItem( + "v1", + subset="test", + media=Video(video_path, step=1, start_frame=2, end_frame=2), + annotations=[ + Bbox(1, 1, 3, 3, label=1, object_id=1), + ], + ), + ], + media_type=MediaElement, + categories=["a", "b"], + ) + + @pytest.fixture def fxt_wrong_version_dir(fxt_test_datumaro_format_dataset, test_dir): dest_dir = osp.join(test_dir, "wrong_version") diff --git a/tests/unit/data_formats/datumaro/test_datumaro_binary_format.py b/tests/unit/data_formats/datumaro/test_datumaro_binary_format.py index 00700bb1e7..68cde7de1d 100644 --- a/tests/unit/data_formats/datumaro/test_datumaro_binary_format.py +++ b/tests/unit/data_formats/datumaro/test_datumaro_binary_format.py @@ -39,10 +39,13 @@ class DatumaroBinaryFormatTest(TestBase): ann_ext = DatumaroBinaryPath.ANNOTATION_EXT @pytest.mark.parametrize( - ["fxt_dataset", "compare", "require_media", "fxt_import_kwargs", "fxt_export_kwargs"], + "fxt_dataset", + ("fxt_test_datumaro_format_dataset", "fxt_test_datumaro_format_video_dataset"), + ) + @pytest.mark.parametrize( + ["compare", "require_media", "fxt_import_kwargs", "fxt_export_kwargs"], [ pytest.param( - "fxt_test_datumaro_format_dataset", compare_datasets_strict, True, {}, @@ -50,7 +53,6 @@ class DatumaroBinaryFormatTest(TestBase): id="test_no_encryption", ), pytest.param( - "fxt_test_datumaro_format_dataset", compare_datasets_strict, True, {"encryption_key": ENCRYPTION_KEY}, @@ -58,7 +60,6 @@ class DatumaroBinaryFormatTest(TestBase): id="test_with_encryption", ), pytest.param( - "fxt_test_datumaro_format_dataset", compare_datasets_strict, True, {"encryption_key": ENCRYPTION_KEY}, @@ -66,7 +67,6 @@ class DatumaroBinaryFormatTest(TestBase): id="test_no_media_encryption", ), pytest.param( - "fxt_test_datumaro_format_dataset", compare_datasets_strict, True, {"encryption_key": ENCRYPTION_KEY}, @@ -74,7 +74,6 @@ class DatumaroBinaryFormatTest(TestBase): id="test_multi_blobs", ), pytest.param( - "fxt_test_datumaro_format_dataset", compare_datasets_strict, True, {"encryption_key": ENCRYPTION_KEY, "num_workers": 2}, @@ -167,10 +166,15 @@ def _get_ann_mapper(ann: Annotation) -> AnnotationMapper: def test_common_mapper(self, mapper: Mapper, expected: Any): self._test(mapper, expected) - def test_annotations_mapper(self, fxt_test_datumaro_format_dataset): - """Test all annotations in fxt_test_datumaro_format_dataset""" + @pytest.mark.parametrize( + "fxt_dataset", + ("fxt_test_datumaro_format_dataset", "fxt_test_datumaro_format_video_dataset"), + ) + def test_annotations_mapper(self, fxt_dataset, request): + """Test all annotations in fxt_dataset""" mapper = DatasetItemMapper - for item in fxt_test_datumaro_format_dataset: + fxt_dataset = request.getfixturevalue(fxt_dataset) + for item in fxt_dataset: for ann in item.annotations: mapper = self._get_ann_mapper(ann) self._test(mapper, ann) diff --git a/tests/unit/data_formats/datumaro/test_datumaro_format.py b/tests/unit/data_formats/datumaro/test_datumaro_format.py index 9a3168df83..5cd4142604 100644 --- a/tests/unit/data_formats/datumaro/test_datumaro_format.py +++ b/tests/unit/data_formats/datumaro/test_datumaro_format.py @@ -76,6 +76,18 @@ def _test_save_and_load( False, id="test_can_save_and_load_with_no_save_media", ), + pytest.param( + "fxt_test_datumaro_format_video_dataset", + compare_datasets, + True, + id="test_can_save_and_load_video_dataset", + ), + pytest.param( + "fxt_test_datumaro_format_video_dataset", + None, + False, + id="test_can_save_and_load_video_dataset_with_no_save_media", + ), pytest.param( "fxt_relative_paths", compare_datasets, @@ -176,8 +188,13 @@ def test_source_target_pair( ) @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_can_detect(self, fxt_test_datumaro_format_dataset, test_dir): - self.exporter.convert(fxt_test_datumaro_format_dataset, save_dir=test_dir) + @pytest.mark.parametrize( + "fxt_dataset", + ("fxt_test_datumaro_format_dataset", "fxt_test_datumaro_format_video_dataset"), + ) + def test_can_detect(self, fxt_dataset, test_dir, request): + fxt_dataset = request.getfixturevalue(fxt_dataset) + self.exporter.convert(fxt_dataset, save_dir=test_dir) detected_formats = Environment().detect_dataset(test_dir) assert [self.importer.NAME] == detected_formats diff --git a/tests/unit/test_kinetics_format.py b/tests/unit/test_kinetics_format.py index 586ddbe900..27a8f32492 100644 --- a/tests/unit/test_kinetics_format.py +++ b/tests/unit/test_kinetics_format.py @@ -1,4 +1,7 @@ -from unittest import TestCase +import os +import os.path as osp + +import pytest from datumaro.components.annotation import Label from datumaro.components.dataset import Dataset, DatasetItem @@ -11,46 +14,68 @@ from tests.utils.assets import get_test_asset_path from tests.utils.test_utils import compare_datasets -DUMMY_DATASET_DIR = get_test_asset_path("kinetics_dataset") +KINETICS_DATASET_DIR = get_test_asset_path("kinetics_dataset") + + +@pytest.fixture +def fxt_kinetics_dataset(test_dir): + def make_video(fname, frame_size=(4, 6), frames=4): + src_path = osp.join(KINETICS_DATASET_DIR, fname) + dst_path = osp.join(test_dir, fname) + if not osp.exists(osp.dirname(dst_path)): + os.makedirs(osp.dirname(dst_path)) + os.symlink(src_path, dst_path) + return Video(dst_path) + return Dataset.from_iterable( + [ + DatasetItem( + id="1", + subset="test", + annotations=[Label(0, attributes={"time_start": 0, "time_end": 2})], + media=make_video("video_1.avi"), + ), + DatasetItem( + id="2", + subset="test", + annotations=[Label(0, attributes={"time_start": 5, "time_end": 7})], + ), + DatasetItem( + id="4", + subset="test", + annotations=[Label(1, attributes={"time_start": 10, "time_end": 15})], + ), + DatasetItem( + id="3", + subset="train", + annotations=[Label(2, attributes={"time_start": 0, "time_end": 2})], + media=make_video("train/3.avi"), + ), + ], + categories=["label_0", "label_1", "label_2"], + media_type=Video, + ) -class KineticsImporterTest(TestCase): + +class KineticsImporterTest: @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_can_detect(self): - detected_formats = Environment().detect_dataset(DUMMY_DATASET_DIR) - self.assertEqual([KineticsImporter.NAME], detected_formats) + detected_formats = Environment().detect_dataset(KINETICS_DATASET_DIR) + assert [KineticsImporter.NAME] == detected_formats @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_can_import_with_video(self): - expected_dataset = Dataset.from_iterable( - [ - DatasetItem( - id="1", - subset="test", - annotations=[Label(0, attributes={"time_start": 0, "time_end": 2})], - media=Video("./video_1.avi"), - ), - DatasetItem( - id="2", - subset="test", - annotations=[Label(0, attributes={"time_start": 5, "time_end": 7})], - ), - DatasetItem( - id="4", - subset="test", - annotations=[Label(1, attributes={"time_start": 10, "time_end": 15})], - ), - DatasetItem( - id="3", - subset="train", - annotations=[Label(2, attributes={"time_start": 0, "time_end": 2})], - media=Video("./train/3.avi"), - ), - ], - categories=["label_0", "label_1", "label_2"], - media_type=Video, - ) - - imported_dataset = Dataset.import_from(DUMMY_DATASET_DIR, "kinetics") - - compare_datasets(self, expected_dataset, imported_dataset, require_media=True) + def test_can_import_with_video(self, helper_tc, fxt_kinetics_dataset): + expected_dataset = fxt_kinetics_dataset + imported_dataset = Dataset.import_from(KINETICS_DATASET_DIR, "kinetics") + + compare_datasets(helper_tc, expected_dataset, imported_dataset, require_media=True) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_convert_to_datumaro_and_export_it(self, helper_tc, test_dir): + imported_dataset = Dataset.import_from(KINETICS_DATASET_DIR, "kinetics") + export_dir = osp.join(test_dir, "dst") + imported_dataset.export(export_dir, "datumaro", save_media=True) + + exported_dataset = Dataset.import_from(export_dir, "datumaro") + + compare_datasets(helper_tc, imported_dataset, exported_dataset, require_media=True) diff --git a/tests/unit/test_video.py b/tests/unit/test_video.py index 0e69e02a1c..574385c098 100644 --- a/tests/unit/test_video.py +++ b/tests/unit/test_video.py @@ -47,12 +47,23 @@ def test_can_read_frames_sequentially(self, fxt_sample_video): video = Video(fxt_sample_video) on_exit_do(video.close) + assert None == video._frame_count for idx, frame in enumerate(video): assert frame.size == video.frame_size assert frame.index == idx assert frame.video is video assert np.array_equal(frame.data, np.ones((*video.frame_size, 3)) * idx) + assert 4 == video._frame_count + for idx, frame in enumerate(video): + assert frame.size == video.frame_size + assert frame.index == idx + assert frame.video is video + assert np.array_equal(frame.data, np.ones((*video.frame_size, 3)) * idx) + + with pytest.raises(IndexError): + video.get_frame_data(idx + 1) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) @scoped def test_can_read_frames_randomly(self, fxt_sample_video): @@ -64,6 +75,9 @@ def test_can_read_frames_randomly(self, fxt_sample_video): assert frame.index == idx assert np.array_equal(frame.data, np.ones((*video.frame_size, 3)) * idx) + with pytest.raises(IndexError): + frame = video[4] + @mark_requirement(Requirements.DATUM_GENERAL_REQ) @scoped def test_can_skip_frames_between(self, fxt_sample_video): @@ -151,7 +165,7 @@ def test_can_split_and_load(self, fxt_sample_video): ) dataset = Dataset.import_from( - fxt_sample_video, "video_frames", start_frame=0, end_frame=4, name_pattern="frame_%06d" + fxt_sample_video, "video_frames", start_frame=0, end_frame=3, name_pattern="frame_%06d" ) dataset.export(format="image_dir", save_dir=test_dir, image_ext=".jpg") diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 9f0cd9e836..b11e1fa0d0 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -21,7 +21,7 @@ from datumaro.components.annotation import AnnotationType from datumaro.components.dataset import Dataset, StreamDataset from datumaro.components.dataset_base import IDataset -from datumaro.components.media import Image, MultiframeImage, PointCloud, VideoFrame +from datumaro.components.media import Image, MultiframeImage, PointCloud, Video, VideoFrame from datumaro.util import filter_dict, find from datumaro.util.os_util import rmfile, rmtree @@ -204,6 +204,8 @@ def compare_datasets( elif isinstance(item_a.media, PointCloud): test.assertEqual(item_a.media.data, item_b.media.data, item_a.id) test.assertEqual(item_a.media.extra_images, item_b.media.extra_images, item_a.id) + elif isinstance(item_a.media, Video): + test.assertEqual(item_a.media, item_b.media, item_a.id) elif isinstance(item_a.media, VideoFrame): test.assertEqual(item_a.media, item_b.media, item_a.id) test.assertEqual(item_a.index, item_b.index, item_a.id) @@ -323,27 +325,28 @@ def check_save_and_load( def _change_path_in_items(dataset, source_path, target_path): for item in dataset: - if item.media and hasattr(item.media, "path"): - path = item.media._path - item.media = item.media.from_self(path=path.replace(source_path, target_path)) - if item.media and isinstance(item.media, PointCloud): - new_images = [] - for image in item.media.extra_images: - if hasattr(image, "path"): - path = image._path - new_images.append( - image.from_self(path=path.replace(source_path, target_path)) - ) - else: - new_images.append(image) - item.media._extra_images = new_images + if item.media: + if hasattr(item.media, "path") and item.media.path: + path = item.media.path.replace(source_path, target_path) + item.media = item.media.from_self(path=path) + if isinstance(item.media, PointCloud): + new_images = [] + for image in item.media.extra_images: + if hasattr(image, "path"): + path = image._path + new_images.append( + image.from_self(path=path.replace(source_path, target_path)) + ) + else: + new_images.append(image) + item.media._extra_images = new_images with TestDir() as tmp_dir: converter(source_dataset, test_dir, stream=stream) if move_save_dir: save_dir = tmp_dir for file in os.listdir(test_dir): - shutil.move(osp.join(test_dir, file), save_dir) + os.symlink(osp.join(test_dir, file), osp.join(save_dir, file)) else: save_dir = test_dir