From de5894666c6930ffee726e63952bc78039849eae Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Sat, 23 Apr 2022 17:57:51 +0800 Subject: [PATCH 1/7] [Features] Initialize dataset with ann_file --- mmflow/datasets/base_dataset.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/mmflow/datasets/base_dataset.py b/mmflow/datasets/base_dataset.py index f77b83cb..1660b7d9 100644 --- a/mmflow/datasets/base_dataset.py +++ b/mmflow/datasets/base_dataset.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy import os.path as osp +import warnings from abc import ABCMeta, abstractmethod from typing import Optional, Sequence, Union @@ -16,6 +17,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta): Args: data_root (str): Directory for dataset. pipeline (Sequence[dict]): Processing pipeline. + ann_file: Annotation file path. Defaults to None. test_mode (bool): Whether the dataset works for model testing or training. """ @@ -23,12 +25,15 @@ class BaseDataset(Dataset, metaclass=ABCMeta): def __init__(self, data_root: str, pipeline: Sequence[dict], + ann_file: Optional[str] = None, + file_client_args: dict = dict(backend='disk'), test_mode: bool = False) -> None: super().__init__() self.data_root = data_root self.pipeline = Compose(pipeline) self.test_mode = test_mode self.dataset_name = self.__class__.__name__ + self.file_client_args = file_client_args """ data_infos is the list of data_info containing img_info and ann_info data_info @@ -41,7 +46,27 @@ def __init__(self, """ self.data_infos = [] - self.load_data_info() + if ann_file is None: + warnings.warn(message='ann_file is None, please use ' + 'tools/prepare_dataset to generate ann_file') + self.load_data_info() + else: + self.load_ann_file(ann_file) + + def load_ann_file(self, ann_file): + ann = mmcv.load( + ann_file, + file_format='json', + file_client_args=self.file_client_args) + self.data_infos = ann['data_list'] + + for data_info in self.data_infos: + 
data_info['filename1'] = \ + osp.join(self.data_root, data_info['filename1']) + data_info['filename2'] = \ + osp.join(self.data_root, data_info['filename2']) + data_info['filename_flow'] = \ + osp.join(self.data_root, data_info['filename_flow']) @abstractmethod def load_data_info(self): From 9aaec502d42a21c497d266a92a233001adff1caa Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Sat, 23 Apr 2022 23:22:40 +0800 Subject: [PATCH 2/7] add ann parse --- mmflow/datasets/base_dataset.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mmflow/datasets/base_dataset.py b/mmflow/datasets/base_dataset.py index 1660b7d9..63a9f9ed 100644 --- a/mmflow/datasets/base_dataset.py +++ b/mmflow/datasets/base_dataset.py @@ -59,12 +59,14 @@ def load_ann_file(self, ann_file): file_format='json', file_client_args=self.file_client_args) self.data_infos = ann['data_list'] - + self.img1_dir = self.data_infos[0]['img1_dir'] + self.img2_dir = self.data_infos[0]['img2_dir'] + self.flow_dir = self.data_infos[0]['flow_dir'] for data_info in self.data_infos: - data_info['filename1'] = \ - osp.join(self.data_root, data_info['filename1']) - data_info['filename2'] = \ - osp.join(self.data_root, data_info['filename2']) + data_info['img_info']['filename1'] = \ + osp.join(self.img1_dir, data_info['filename1']) + data_info['img_info']['filename2'] = \ + osp.join(self.img2_dir, data_info['filename2']) data_info['filename_flow'] = \ osp.join(self.data_root, data_info['filename_flow']) From 8e1e4b03856c2a1190340bc8bf135003f0a033b5 Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Sat, 23 Apr 2022 23:29:11 +0800 Subject: [PATCH 3/7] fix --- mmflow/datasets/base_dataset.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/mmflow/datasets/base_dataset.py b/mmflow/datasets/base_dataset.py index 63a9f9ed..d72af500 100644 --- a/mmflow/datasets/base_dataset.py +++ b/mmflow/datasets/base_dataset.py @@ -54,14 +54,22 @@ def __init__(self, 
self.load_ann_file(ann_file) def load_ann_file(self, ann_file): + """_summary_ + + Args: + ann_file (_type_): _description_ + """ ann = mmcv.load( ann_file, file_format='json', file_client_args=self.file_client_args) self.data_infos = ann['data_list'] - self.img1_dir = self.data_infos[0]['img1_dir'] - self.img2_dir = self.data_infos[0]['img2_dir'] - self.flow_dir = self.data_infos[0]['flow_dir'] + self.img1_dir = osp.join(self.data_root, + self.data_infos[0]['img1_dir']) + self.img2_dir = osp.join(self.data_root, + self.data_infos[0]['img2_dir']) + self.flow_dir = osp.join(self.data_root, + self.data_infos[0]['flow_dir']) for data_info in self.data_infos: data_info['img_info']['filename1'] = \ osp.join(self.img1_dir, data_info['filename1']) From a8a5922bc8e5c2bd76f5cece4f62075ab34f029a Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Sat, 23 Apr 2022 23:38:22 +0800 Subject: [PATCH 4/7] img info ann info --- mmflow/datasets/base_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmflow/datasets/base_dataset.py b/mmflow/datasets/base_dataset.py index d72af500..57f6efe6 100644 --- a/mmflow/datasets/base_dataset.py +++ b/mmflow/datasets/base_dataset.py @@ -72,11 +72,11 @@ def load_ann_file(self, ann_file): self.data_infos[0]['flow_dir']) for data_info in self.data_infos: data_info['img_info']['filename1'] = \ - osp.join(self.img1_dir, data_info['filename1']) + osp.join(self.img1_dir, data_info['img_info']['filename1']) data_info['img_info']['filename2'] = \ - osp.join(self.img2_dir, data_info['filename2']) - data_info['filename_flow'] = \ - osp.join(self.data_root, data_info['filename_flow']) + osp.join(self.img2_dir, data_info['img_info']['filename2']) + data_info['ann_info']['filename_flow'] = osp.join( + self.data_root, data_info['ann_info']['filename_flow']) @abstractmethod def load_data_info(self): From 69e8586ce6847f4996c8c31867fa5c205396e699 Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Sun, 24 Apr 2022 10:56:29 +0800 Subject: 
[PATCH 5/7] docstring --- mmflow/datasets/base_dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mmflow/datasets/base_dataset.py b/mmflow/datasets/base_dataset.py index 57f6efe6..60748fbf 100644 --- a/mmflow/datasets/base_dataset.py +++ b/mmflow/datasets/base_dataset.py @@ -53,11 +53,12 @@ def __init__(self, else: self.load_ann_file(ann_file) - def load_ann_file(self, ann_file): - """_summary_ + def load_ann_file(self, ann_file: str) -> None: + """Load annotation file. Args: - ann_file (_type_): _description_ + ann_file (str): The json file contains the data sample + information. """ ann = mmcv.load( ann_file, From 502942e54ed7f01912305fcfa9d2e5266377b9b7 Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Sun, 24 Apr 2022 11:20:46 +0800 Subject: [PATCH 6/7] fix flow_dir --- mmflow/datasets/base_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmflow/datasets/base_dataset.py b/mmflow/datasets/base_dataset.py index 60748fbf..06f54bef 100644 --- a/mmflow/datasets/base_dataset.py +++ b/mmflow/datasets/base_dataset.py @@ -77,7 +77,7 @@ def load_ann_file(self, ann_file: str) -> None: data_info['img_info']['filename2'] = \ osp.join(self.img2_dir, data_info['img_info']['filename2']) data_info['ann_info']['filename_flow'] = osp.join( - self.data_root, data_info['ann_info']['filename_flow']) + self.flow_dir, data_info['ann_info']['filename_flow']) @abstractmethod def load_data_info(self): From 0fe8eacf2790bd65a6d2dec39dd95d8c73176903 Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Sun, 24 Apr 2022 11:21:52 +0800 Subject: [PATCH 7/7] fix docstring --- mmflow/datasets/base_dataset.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mmflow/datasets/base_dataset.py b/mmflow/datasets/base_dataset.py index 06f54bef..b710247d 100644 --- a/mmflow/datasets/base_dataset.py +++ b/mmflow/datasets/base_dataset.py @@ -18,6 +18,9 @@ class BaseDataset(Dataset, metaclass=ABCMeta): data_root (str): Directory for dataset. 
pipeline (Sequence[dict]): Processing pipeline. ann_file: Annotation file path. Defaults to None. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. test_mode (bool): Whether the dataset works for model testing or training. """