Skip to content

Commit

Permalink
Add dataset categories to data modules (#1105)
Browse files Browse the repository at this point in the history
* Fix metadata path

* Add categories to data modules
  • Loading branch information
samet-akcay authored May 25, 2023
1 parent 0cd967b commit 5eff4e6
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 7 deletions.
10 changes: 6 additions & 4 deletions src/anomalib/data/btech.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
name="btech", url="https://avires.dimi.uniud.it/papers/btad/btad.zip", hash="c1fa4d56ac50dd50908ce04e81037a8e"
)

CATEGORIES = ("01", "02", "03")


def make_btech_dataset(path: Path, split: str | Split | None = None) -> DataFrame:
"""Create BTech samples by parsing the BTech data file structure.
Expand Down Expand Up @@ -134,7 +136,7 @@ class BTechDataset(AnomalibDataset):
>>> transform = get_transforms(image_size=256)
>>> dataset = BTechDataset(
... root='./datasets/BTech',
... category='leather',
... category='01',
... transform=transform,
... task="classification",
... is_train=True,
Expand Down Expand Up @@ -215,13 +217,13 @@ class BTech(AnomalibDataModule):
>>> from anomalib.data import BTech
>>> datamodule = BTech(
... root="./datasets/BTech",
... category="leather",
... category="01",
... image_size=256,
... train_batch_size=32,
... test_batch_size=32,
... eval_batch_size=32,
... num_workers=8,
... transform_config_train=None,
... transform_config_val=None,
... transform_config_eval=None,
... )
>>> datamodule.setup()
Expand Down
20 changes: 19 additions & 1 deletion src/anomalib/data/mvtec.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
logger = logging.getLogger(__name__)


IMG_EXTENSIONS = [".png", ".PNG"]
IMG_EXTENSIONS = (".png", ".PNG")

DOWNLOAD_INFO = DownloadInfo(
name="mvtec",
Expand All @@ -57,6 +57,24 @@
hash="eefca59f2cede9c3fc5b6befbfec275e",
)

CATEGORIES = (
"bottle",
"cable",
"capsule",
"carpet",
"grid",
"hazelnut",
"leather",
"metal_nut",
"pill",
"screw",
"tile",
"toothbrush",
"transistor",
"wood",
"zipper",
)


def make_mvtec_dataset(
root: str | Path, split: str | Split | None = None, extensions: Sequence[str] | None = None
Expand Down
2 changes: 2 additions & 0 deletions src/anomalib/data/mvtec_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
hash="d8bb2800fbf3ac88e798da6ae10dc819",
)

CATEGORIES = ("bagel", "cable_gland", "carrot", "cookie", "dowel", "foam", "peach", "potato", "rope", "tire")


def make_mvtec_3d_dataset(
root: str | Path, split: str | Split | None = None, extensions: Sequence[str] | None = None
Expand Down
6 changes: 4 additions & 2 deletions src/anomalib/data/ucsd_ped.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
hash="5006421b89885f45a6f93b041145f2eb",
)

CATEGORIES = ("UCSDped1", "UCSDped2")


def make_ucsd_dataset(path: Path, split: str | Split | None = None) -> DataFrame:
"""Create UCSD Pedestrian dataset by parsing the file structure.
Expand Down Expand Up @@ -151,7 +153,7 @@ class UCSDpedDataset(AnomalibVideoDataset):
Args:
task (TaskType): Task type, 'classification', 'detection' or 'segmentation'
root (Path | str): Path to the root of the dataset
category (str): Sub-category of the dataset, e.g. 'bottle'
category (str): Sub-category of the dataset, e.g. "UCSDped1" or "UCSDped2"
transform (A.Compose): Albumentations Compose object describing the transforms that are applied to the inputs.
split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST
clip_length_in_frames (int, optional): Number of video frames in each clip.
Expand Down Expand Up @@ -186,7 +188,7 @@ class UCSDped(AnomalibVideoDataModule):
Args:
root (Path | str): Path to the root of the dataset
category (str): Sub-category of the dataset, e.g. 'bottle'
category (str): Sub-category of the dataset, e.g. "UCSDped1" or "UCSDped2"
clip_length_in_frames (int, optional): Number of video frames in each clip.
frames_between_clips (int, optional): Number of frames between each consecutive video clip.
target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval
Expand Down
15 changes: 15 additions & 0 deletions src/anomalib/data/visa.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,21 @@
hash="ef908989b6dc701fc218f643c127a4de",
)

CATEGORIES = (
"candle",
"capsules",
"cashew",
"chewinggum",
"fryum",
"macaroni1",
"macaroni2",
"pcb1",
"pcb2",
"pcb3",
"pcb4",
"pipe_fryum",
)


class VisaDataset(AnomalibDataset):
"""VisA dataset class.
Expand Down

0 comments on commit 5eff4e6

Please sign in to comment.