Fix validator definitions #303

Merged · 3 commits · Jun 19, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Subformat importers for VOC and COCO (<https://github.com/openvinotoolkit/datumaro/pull/281>)
 - Support for KITTI dataset segmentation and detection format (<https://github.com/openvinotoolkit/datumaro/pull/282>)
 - Updated YOLO format user manual (<https://github.com/openvinotoolkit/datumaro/pull/295>)
+- A base class for dataset validation plugins (<https://github.com/openvinotoolkit/datumaro/pull/299>)
 
 ### Changed
 -
17 changes: 13 additions & 4 deletions datumaro/cli/contexts/project/__init__.py
@@ -18,7 +18,7 @@
 from datumaro.components.project import \
     PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG
 from datumaro.components.project import Environment, Project
-from datumaro.components.validator import Validator, TaskType
+from datumaro.components.validator import TaskType
 from datumaro.util import error_rollback
 
 from ...util import (CliException, MultilineFormatter, add_subparser,
@@ -794,17 +794,26 @@ def print_extractor_info(extractor, indent=''):
     return 0
 
 def build_validate_parser(parser_ctor=argparse.ArgumentParser):
+    def _parse_task_type(s):
+        try:
+            return TaskType[s.lower()].name
+        except:
+            raise argparse.ArgumentTypeError("Unknown task type %s. Expected "
+                "one of: %s" % (s, ', '.join(t.name for t in TaskType)))
+
+
     parser = parser_ctor(help="Validate project",
         description="""
        Validates project based on specified task type and stores
        results like statistics, reports and summary in JSON file.
        """,
         formatter_class=MultilineFormatter)
 
-    parser.add_argument('-t', '--task_type', choices=[task_type.name for task_type in TaskType],
-        help="Task type for validation")
+    parser.add_argument('-t', '--task_type', type=_parse_task_type,
+        help="Task type for validation, one of %s" % \
+            ', '.join(t.name for t in TaskType))
     parser.add_argument('-s', '--subset', dest='subset_name', default=None,
-        help="Subset to validate (default: None)")
+        help="Subset to validate (default: whole dataset)")
     parser.add_argument('-p', '--project', dest='project_dir', default='.',
         help="Directory of the project to validate (default: current dir)")
     parser.add_argument('extra_args', nargs=argparse.REMAINDER, default=None,
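The new `_parse_task_type` helper replaces argparse's `choices=` check so task types match case-insensitively while bad input still fails with a clear message. A minimal, self-contained sketch of the same pattern (the enum members mirror `TaskType` from the diff; everything else is illustrative):

```python
import argparse
from enum import Enum, auto

class TaskType(Enum):
    classification = auto()
    detection = auto()
    segmentation = auto()

def _parse_task_type(s):
    # Lowercase first so 'Detection' and 'DETECTION' both match
    try:
        return TaskType[s.lower()].name
    except KeyError:
        raise argparse.ArgumentTypeError("Unknown task type %s. Expected "
            "one of: %s" % (s, ', '.join(t.name for t in TaskType)))

parser = argparse.ArgumentParser()
parser.add_argument('-t', '--task_type', type=_parse_task_type)
print(parser.parse_args(['-t', 'Detection']).task_type)  # -> 'detection'
```

With `choices=`, argparse compares the raw string against the member names, so only an exact lowercase spelling was accepted; a `type=` callable normalizes the value before any comparison happens.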
52 changes: 45 additions & 7 deletions datumaro/components/validator.py
@@ -19,17 +19,55 @@ class TaskType(Enum):
 segmentation = auto()
 
 
-class IValidator:
+class Validator:
     def validate(self, dataset: IDataset) -> Dict:
-        raise NotImplementedError()
-
-
-class Validator(IValidator):
-    def validate(self, dataset: IDataset) -> Dict:
-        raise NotImplementedError()
+        """
+        Returns the validation results of a dataset based on task type.
+
+        Args:
+            dataset (IDataset): Dataset to be validated
+
+        Raises:
+            ValueError
+
+        Returns:
+            validation_results (dict):
+                Dict with validation statistics, reports and summary.
+        """
+
+        validation_results = {}
+        if not isinstance(dataset, IDataset):
+            raise TypeError("Invalid dataset type '%s'" % type(dataset))
+
+        # generate statistics
+        stats = self.compute_statistics(dataset)
+        validation_results['statistics'] = stats
+
+        # generate validation reports and summary
+        reports = self.generate_reports(stats)
+        reports = list(map(lambda r: r.to_dict(), reports))
+
+        summary = {
+            'errors': sum(map(lambda r: r['severity'] == 'error', reports)),
+            'warnings': sum(map(lambda r: r['severity'] == 'warning', reports))
+        }
+
+        validation_results['validation_reports'] = reports
+        validation_results['summary'] = summary
+
+        return validation_results
 
     def compute_statistics(self, dataset: IDataset) -> Dict:
-        raise NotImplementedError()
+        """
+        Computes statistics of the dataset based on task type.
+
+        Args:
+            dataset (IDataset): a dataset to be validated
+
+        Returns:
+            stats (dict): A dict object containing statistics of the dataset.
+        """
+        raise NotImplementedError("Must be implemented in a subclass")
 
     def generate_reports(self, stats: Dict) -> List[Dict]:
-        raise NotImplementedError()
+        raise NotImplementedError("Must be implemented in a subclass")
151 changes: 24 additions & 127 deletions datumaro/plugins/validators.py
@@ -3,16 +3,11 @@
 # SPDX-License-Identifier: MIT
 
 from copy import deepcopy
-from typing import Dict, List
 
-import json
 import logging as log
 
 import numpy as np
 
 from datumaro.components.validator import (Severity, TaskType, Validator)
 from datumaro.components.cli_plugin import CliPlugin
-from datumaro.components.dataset import IDataset
 from datumaro.components.errors import (MissingLabelCategories,
     MissingAnnotation, MultiLabelAnnotations, MissingAttribute,
     UndefinedLabel, UndefinedAttribute, LabelDefinedButNotFound,
@@ -25,7 +20,7 @@
 from datumaro.util import parse_str_enum_value
 
 
-class _TaskValidator(Validator):
+class _TaskValidator(Validator, CliPlugin):
     # statistics templates
     numerical_stat_template = {
         'items_far_from_mean': {},
@@ -48,17 +43,28 @@ class _TaskValidator(Validator):
     ----------
     task_type : str or TaskType
         task type (ie. classification, detection, segmentation)
-
-    Methods
-    -------
-    validate(dataset):
-        Validate annotations based on task type.
-    compute_statistics(dataset):
-        Computes various statistics of the dataset based on task type.
-    generate_reports(stats):
-        Abstract method that must be implemented in a subclass.
     """
 
+    @classmethod
+    def build_cmdline_parser(cls, **kwargs):
+        parser = super().build_cmdline_parser(**kwargs)
+        parser.add_argument('-fs', '--few_samples_thr', default=1, type=int,
+            help="Threshold for giving a warning for minimum number of "
+                "samples per class")
+        parser.add_argument('-ir', '--imbalance_ratio_thr', default=50, type=int,
+            help="Threshold for giving data imbalance warning; "
+                "IR(imbalance ratio) = majority/minority")
+        parser.add_argument('-m', '--far_from_mean_thr', default=5.0, type=float,
+            help="Threshold for giving a warning that data is far from mean; "
+                "a constant used to define mean +/- k * standard deviation")
+        parser.add_argument('-dr', '--dominance_ratio_thr', default=0.8, type=float,
+            help="Threshold for giving a warning for bounding box imbalance; "
+                "Dominance_ratio = ratio of Top-k bin to total in histogram")
+        parser.add_argument('-k', '--topk_bins', default=0.1, type=float,
+            help="Ratio of bins with the highest number of data "
+                "to total bins in the histogram; [0, 1]; 0.1 = 10%")
+        return parser
 
     def __init__(self, task_type, few_samples_thr=None,
             imbalance_ratio_thr=None, far_from_mean_thr=None,
             dominance_ratio_thr=None, topk_bins=None):
@@ -102,41 +108,6 @@ def __init__(self, task_type, few_samples_thr=None,
         self.dominance_thr = dominance_ratio_thr
         self.topk_bins_ratio = topk_bins
 
-    def validate(self, dataset: IDataset):
-        """
-        Returns the validation results of a dataset based on task type.
-        Args:
-            dataset (IDataset): Dataset to be validated
-            task_type (str or TaskType): Type of the task
-                (classification, detection, segmentation)
-        Raises:
-            ValueError
-        Returns:
-            validation_results (dict):
-                Dict with validation statistics, reports and summary.
-        """
-        validation_results = {}
-        if not isinstance(dataset, IDataset):
-            raise TypeError("Invalid dataset type '%s'" % type(dataset))
-
-        # generate statistics
-        stats = self.compute_statistics(dataset)
-        validation_results['statistics'] = stats
-
-        # generate validation reports and summary
-        reports = self.generate_reports(stats)
-        reports = list(map(lambda r: r.to_dict(), reports))
-
-        summary = {
-            'errors': sum(map(lambda r: r['severity'] == 'error', reports)),
-            'warnings': sum(map(lambda r: r['severity'] == 'warning', reports))
-        }
-
-        validation_results['validation_reports'] = reports
-        validation_results['summary'] = summary
-
-        return validation_results
 
     def _compute_common_statistics(self, dataset):
         defined_attr_template = {
             'items_missing_attribute': [],
@@ -285,20 +256,6 @@ def _far_from_mean(val, mean, stdev):
                     item_key, {})
                 far_from_mean[ann.id] = val
 
-    def compute_statistics(self, dataset: IDataset):
-        """
-        Computes statistics of the dataset based on task type.
-
-        Parameters
-        ----------
-        dataset : IDataset object
-
-        Returns
-        -------
-        stats (dict): A dict object containing statistics of the dataset.
-        """
-        return NotImplementedError
 
     def _check_missing_label_categories(self, stats):
         validation_reports = []

@@ -578,36 +535,14 @@ def _check_far_from_attr_mean(self, label_name, attr_name, attr_stats):

         return validation_reports
 
-    def generate_reports(self, stats: Dict) -> List[Dict]:
-        raise NotImplementedError('Should be implemented in a subclass.')
-
     def _generate_validation_report(self, error, *args, **kwargs):
         return [error(*args, **kwargs)]
 
 
-class ClassificationValidator(_TaskValidator, CliPlugin):
+class ClassificationValidator(_TaskValidator):
     """
     A specific validator class for classification task.
     """
-    @classmethod
-    def build_cmdline_parser(cls, **kwargs):
-        parser = super().build_cmdline_parser(**kwargs)
-        parser.add_argument('-fs', '--few_samples_thr', default=1, type=int,
-            help="Threshold for giving a warning for minimum number of"
-            "samples per class")
-        parser.add_argument('-ir', '--imbalance_ratio_thr', default=50, type=int,
-            help="Threshold for giving data imbalance warning;"
-            "IR(imbalance ratio) = majority/minority")
-        parser.add_argument('-m', '--far_from_mean_thr', default=5.0, type=float,
-            help="Threshold for giving a warning that data is far from mean;"
-            "A constant used to define mean +/- k * standard deviation;")
-        parser.add_argument('-dr', '--dominance_ratio_thr', default=0.8, type=float,
-            help="Threshold for giving a warning for bounding box imbalance;"
-            "Dominace_ratio = ratio of Top-k bin to total in histogram;")
-        parser.add_argument('-k', '--topk_bins', default=0.1, type=float,
-            help="Ratio of bins with the highest number of data"
-            "to total bins in the histogram; [0, 1]; 0.1 = 10%;")
-        return parser
 
     def __init__(self, few_samples_thr, imbalance_ratio_thr,
             far_from_mean_thr, dominance_ratio_thr, topk_bins):
@@ -709,29 +644,10 @@ def generate_reports(self, stats):
         return reports
 
 
-class DetectionValidator(_TaskValidator, CliPlugin):
+class DetectionValidator(_TaskValidator):
     """
     A specific validator class for detection task.
     """
-    @classmethod
-    def build_cmdline_parser(cls, **kwargs):
-        parser = super().build_cmdline_parser(**kwargs)
-        parser.add_argument('-fs', '--few_samples_thr', default=1, type=int,
-            help="Threshold for giving a warning for minimum number of"
-            "samples per class")
-        parser.add_argument('-ir', '--imbalance_ratio_thr', default=50, type=int,
-            help="Threshold for giving data imbalance warning;"
-            "IR(imbalance ratio) = majority/minority")
-        parser.add_argument('-m', '--far_from_mean_thr', default=5.0, type=float,
-            help="Threshold for giving a warning that data is far from mean;"
-            "A constant used to define mean +/- k * standard deviation;")
-        parser.add_argument('-dr', '--dominance_ratio_thr', default=0.8, type=float,
-            help="Threshold for giving a warning for bounding box imbalance;"
-            "Dominace_ratio = ratio of Top-k bin to total in histogram;")
-        parser.add_argument('-k', '--topk_bins', default=0.1, type=float,
-            help="Ratio of bins with the highest number of data"
-            "to total bins in the histogram; [0, 1]; 0.1 = 10%;")
-        return parser
 
     def __init__(self, few_samples_thr, imbalance_ratio_thr,
             far_from_mean_thr, dominance_ratio_thr, topk_bins):
@@ -1014,29 +930,10 @@ def generate_reports(self, stats):
         return reports
 
 
-class SegmentationValidator(_TaskValidator, CliPlugin):
+class SegmentationValidator(_TaskValidator):
    """
    A specific validator class for (instance) segmentation task.
    """
-    @classmethod
-    def build_cmdline_parser(cls, **kwargs):
-        parser = super().build_cmdline_parser(**kwargs)
-        parser.add_argument('-fs', '--few_samples_thr', default=1, type=int,
-            help="Threshold for giving a warning for minimum number of"
-            "samples per class")
-        parser.add_argument('-ir', '--imbalance_ratio_thr', default=50, type=int,
-            help="Threshold for giving data imbalance warning;"
-            "IR(imbalance ratio) = majority/minority")
-        parser.add_argument('-m', '--far_from_mean_thr', default=5.0, type=float,
-            help="Threshold for giving a warning that data is far from mean;"
-            "A constant used to define mean +/- k * standard deviation;")
-        parser.add_argument('-dr', '--dominance_ratio_thr', default=0.8, type=float,
-            help="Threshold for giving a warning for bounding box imbalance;"
-            "Dominace_ratio = ratio of Top-k bin to total in histogram;")
-        parser.add_argument('-k', '--topk_bins', default=0.1, type=float,
-            help="Ratio of bins with the highest number of data"
-            "to total bins in the histogram; [0, 1]; 0.1 = 10%;")
-        return parser
 
     def __init__(self, few_samples_thr, imbalance_ratio_thr,
             far_from_mean_thr, dominance_ratio_thr, topk_bins):
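The net effect in this file: `build_cmdline_parser()` and the `CliPlugin` base move from the three concrete validators up into `_TaskValidator`, deleting three identical copies of the parser, so each task validator now differs only in its statistics and report logic. The consolidated CLI thresholds map one-to-one onto the constructor arguments, so building a validator directly looks like this (a sketch; the values shown are simply the parser defaults from the diff):

```python
from datumaro.plugins.validators import DetectionValidator

validator = DetectionValidator(
    few_samples_thr=1,        # warn when a class has too few samples
    imbalance_ratio_thr=50,   # warn when majority/minority exceeds 50
    far_from_mean_thr=5.0,    # warn outside mean +/- 5 * stdev
    dominance_ratio_thr=0.8,  # warn when top-k bins hold >80% of boxes
    topk_bins=0.1,            # top 10% of histogram bins form "top-k"
)
```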
3 changes: 2 additions & 1 deletion tests/test_validator.py
@@ -18,7 +18,8 @@
     FarFromAttrMean, OnlyOneAttributeValue)
 from datumaro.components.extractor import Bbox, Label, Mask, Polygon
 from datumaro.components.validator import TaskType
-from datumaro.plugins.validators import (_TaskValidator, ClassificationValidator, DetectionValidator, SegmentationValidator)
+from datumaro.plugins.validators import (_TaskValidator,
+    ClassificationValidator, DetectionValidator, SegmentationValidator)
 from .requirements import Requirements, mark_requirement


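The wrapped import above is the only test change; the validators are exercised the same way as before. For context, an end-to-end run looks roughly like this (a sketch following the usual Datumaro test pattern; the item ids and category names are made up):

```python
from datumaro.components.dataset import Dataset
from datumaro.components.extractor import DatasetItem, Label
from datumaro.plugins.validators import ClassificationValidator

dataset = Dataset.from_iterable([
    DatasetItem(id=1, annotations=[Label(0)]),
    DatasetItem(id=2, annotations=[Label(1)]),
], categories=['cat', 'dog'])

validator = ClassificationValidator(few_samples_thr=1,
    imbalance_ratio_thr=50, far_from_mean_thr=5.0,
    dominance_ratio_thr=0.8, topk_bins=0.1)

results = validator.validate(dataset)
print(results['summary'])  # {'errors': ..., 'warnings': ...}
```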