Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dataset categories to data modules #1105

Merged
merged 15 commits on May 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/anomalib/data/btech.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
name="btech", url="https://avires.dimi.uniud.it/papers/btad/btad.zip", hash="c1fa4d56ac50dd50908ce04e81037a8e"
)

CATEGORIES = ("01", "02", "03")


def make_btech_dataset(path: Path, split: str | Split | None = None) -> DataFrame:
"""Create BTech samples by parsing the BTech data file structure.
Expand Down Expand Up @@ -134,7 +136,7 @@ class BTechDataset(AnomalibDataset):
>>> transform = get_transforms(image_size=256)
>>> dataset = BTechDataset(
... root='./datasets/BTech',
... category='leather',
... category='01',
... transform=transform,
... task="classification",
... is_train=True,
Expand Down Expand Up @@ -215,13 +217,13 @@ class BTech(AnomalibDataModule):
>>> from anomalib.data import BTech
>>> datamodule = BTech(
... root="./datasets/BTech",
... category="leather",
... category="01",
... image_size=256,
... train_batch_size=32,
... test_batch_size=32,
... eval_batch_size=32,
... num_workers=8,
... transform_config_train=None,
... transform_config_val=None,
... transform_config_eval=None,
... )
>>> datamodule.setup()

Expand Down
20 changes: 19 additions & 1 deletion src/anomalib/data/mvtec.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
logger = logging.getLogger(__name__)


IMG_EXTENSIONS = [".png", ".PNG"]
IMG_EXTENSIONS = (".png", ".PNG")

DOWNLOAD_INFO = DownloadInfo(
name="mvtec",
Expand All @@ -57,6 +57,24 @@
hash="eefca59f2cede9c3fc5b6befbfec275e",
)

CATEGORIES = (
"bottle",
"cable",
"capsule",
"carpet",
"grid",
"hazelnut",
"leather",
"metal_nut",
"pill",
"screw",
"tile",
"toothbrush",
"transistor",
"wood",
"zipper",
)


def make_mvtec_dataset(
root: str | Path, split: str | Split | None = None, extensions: Sequence[str] | None = None
Expand Down
2 changes: 2 additions & 0 deletions src/anomalib/data/mvtec_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
hash="d8bb2800fbf3ac88e798da6ae10dc819",
)

CATEGORIES = ("bagel", "cable_gland", "carrot", "cookie", "dowel", "foam", "peach", "potato", "rope", "tire")


def make_mvtec_3d_dataset(
root: str | Path, split: str | Split | None = None, extensions: Sequence[str] | None = None
Expand Down
6 changes: 4 additions & 2 deletions src/anomalib/data/ucsd_ped.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
hash="5006421b89885f45a6f93b041145f2eb",
)

CATEGORIES = ("UCSDped1", "UCSDped2")


def make_ucsd_dataset(path: Path, split: str | Split | None = None) -> DataFrame:
"""Create UCSD Pedestrian dataset by parsing the file structure.
Expand Down Expand Up @@ -151,7 +153,7 @@ class UCSDpedDataset(AnomalibVideoDataset):
Args:
task (TaskType): Task type, 'classification', 'detection' or 'segmentation'
root (Path | str): Path to the root of the dataset
category (str): Sub-category of the dataset, e.g. 'bottle'
category (str): Sub-category of the dataset, e.g. "UCSDped1" or "UCSDped2"
transform (A.Compose): Albumentations Compose object describing the transforms that are applied to the inputs.
split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST
clip_length_in_frames (int, optional): Number of video frames in each clip.
Expand Down Expand Up @@ -186,7 +188,7 @@ class UCSDped(AnomalibVideoDataModule):

Args:
root (Path | str): Path to the root of the dataset
category (str): Sub-category of the dataset, e.g. 'bottle'
category (str): Sub-category of the dataset, e.g. "UCSDped1" or "UCSDped2"
clip_length_in_frames (int, optional): Number of video frames in each clip.
frames_between_clips (int, optional): Number of frames between each consecutive video clip.
target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval
Expand Down
15 changes: 15 additions & 0 deletions src/anomalib/data/visa.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,21 @@
hash="ef908989b6dc701fc218f643c127a4de",
)

CATEGORIES = (
"candle",
"capsules",
"cashew",
"chewinggum",
"fryum",
"macaroni1",
"macaroni2",
"pcb1",
"pcb2",
"pcb3",
"pcb4",
"pipe_fryum",
)


class VisaDataset(AnomalibDataset):
"""VisA dataset class.
Expand Down