From cbe41439ca6a100d9685a5d4e0aaf84749af8d4b Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Fri, 23 Sep 2022 04:36:44 -0400 Subject: [PATCH] Dataset specific args method to CIFAR10, ImageNet, MNIST, and STL10 (#890) --- pl_bolts/datamodules/cifar10_datamodule.py | 11 +++++++++++ pl_bolts/datamodules/imagenet_datamodule.py | 11 +++++++++++ pl_bolts/datamodules/mnist_datamodule.py | 11 +++++++++++ pl_bolts/datamodules/stl10_datamodule.py | 11 +++++++++++ 4 files changed, 44 insertions(+) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 06799b5c8d..8526739b4f 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -1,3 +1,4 @@ +from argparse import ArgumentParser from typing import Any, Callable, Optional, Sequence, Union from pl_bolts.datamodules.vision_datamodule import VisionDataModule @@ -122,6 +123,16 @@ def default_transforms(self) -> Callable: return cf10_transforms + @staticmethod + def add_dataset_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: + parser = ArgumentParser(parents=[parent_parser], add_help=False) + + parser.add_argument("--data_dir", type=str, default=".") + parser.add_argument("--num_workers", type=int, default=0) + parser.add_argument("--batch_size", type=int, default=32) + + return parser + @under_review() class TinyCIFAR10DataModule(CIFAR10DataModule): diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index e2c5ff24a6..a01e4dd4db 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -1,4 +1,5 @@ import os +from argparse import ArgumentParser from typing import Any, Callable, Optional from pytorch_lightning import LightningDataModule @@ -259,3 +260,13 @@ def val_transform(self) -> Callable: ] ) return preprocessing + + @staticmethod + def add_dataset_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: + parser = ArgumentParser(parents=[parent_parser], add_help=False) + + parser.add_argument("--data_dir", type=str, default=".") + parser.add_argument("--num_workers", type=int, default=0) + parser.add_argument("--batch_size", type=int, default=32) + + return parser diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index b953fb31f1..20b548d1c6 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -1,3 +1,4 @@ +from argparse import ArgumentParser from typing import Any, Callable, Optional, Union from pl_bolts.datamodules.vision_datamodule import VisionDataModule @@ -106,3 +107,13 @@ def default_transforms(self) -> Callable: mnist_transforms = transform_lib.Compose([transform_lib.ToTensor()]) return mnist_transforms + + @staticmethod + def add_dataset_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: + parser = ArgumentParser(parents=[parent_parser], add_help=False) + + parser.add_argument("--data_dir", type=str, default=".") + parser.add_argument("--num_workers", type=int, default=0) + parser.add_argument("--batch_size", type=int, default=32) + + return parser diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index e765d45390..4774b84658 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -1,4 +1,5 @@ import os +from argparse import ArgumentParser from typing import Any, Callable, Optional import torch @@ -304,3 +305,13 @@ def val_dataloader_labeled(self) -> DataLoader: def _default_transforms(self) -> Callable: data_transforms = transform_lib.Compose([transform_lib.ToTensor(), stl10_normalization()]) return data_transforms + + @staticmethod + def add_dataset_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: + parser = ArgumentParser(parents=[parent_parser], add_help=False) + + parser.add_argument("--data_dir", type=str, default=".") + parser.add_argument("--num_workers", type=int, default=0) + parser.add_argument("--batch_size", type=int, default=32) + + return parser