Skip to content

Commit

Permalink
Dataset specific args method to CIFAR10, ImageNet, MNIST, and STL10 (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
matsumotosan authored Sep 23, 2022
1 parent e97ae97 commit cbe4143
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 0 deletions.
11 changes: 11 additions & 0 deletions pl_bolts/datamodules/cifar10_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from argparse import ArgumentParser
from typing import Any, Callable, Optional, Sequence, Union

from pl_bolts.datamodules.vision_datamodule import VisionDataModule
Expand Down Expand Up @@ -122,6 +123,16 @@ def default_transforms(self) -> Callable:

return cf10_transforms

@staticmethod
def add_dataset_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
parser = ArgumentParser(parents=[parent_parser], add_help=False)

parser.add_argument("--data_dir", type=str, default=".")
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--batch_size", type=int, default=32)

return parser


@under_review()
class TinyCIFAR10DataModule(CIFAR10DataModule):
Expand Down
11 changes: 11 additions & 0 deletions pl_bolts/datamodules/imagenet_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from argparse import ArgumentParser
from typing import Any, Callable, Optional

from pytorch_lightning import LightningDataModule
Expand Down Expand Up @@ -259,3 +260,13 @@ def val_transform(self) -> Callable:
]
)
return preprocessing

@staticmethod
def add_dataset_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
parser = ArgumentParser(parents=[parent_parser], add_help=False)

parser.add_argument("--data_dir", type=str, default=".")
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--batch_size", type=int, default=32)

return parser
11 changes: 11 additions & 0 deletions pl_bolts/datamodules/mnist_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from argparse import ArgumentParser
from typing import Any, Callable, Optional, Union

from pl_bolts.datamodules.vision_datamodule import VisionDataModule
Expand Down Expand Up @@ -106,3 +107,13 @@ def default_transforms(self) -> Callable:
mnist_transforms = transform_lib.Compose([transform_lib.ToTensor()])

return mnist_transforms

@staticmethod
def add_dataset_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
parser = ArgumentParser(parents=[parent_parser], add_help=False)

parser.add_argument("--data_dir", type=str, default=".")
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--batch_size", type=int, default=32)

return parser
11 changes: 11 additions & 0 deletions pl_bolts/datamodules/stl10_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from argparse import ArgumentParser
from typing import Any, Callable, Optional

import torch
Expand Down Expand Up @@ -304,3 +305,13 @@ def val_dataloader_labeled(self) -> DataLoader:
def _default_transforms(self) -> Callable:
data_transforms = transform_lib.Compose([transform_lib.ToTensor(), stl10_normalization()])
return data_transforms

@staticmethod
def add_dataset_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
parser = ArgumentParser(parents=[parent_parser], add_help=False)

parser.add_argument("--data_dir", type=str, default=".")
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--batch_size", type=int, default=32)

return parser

0 comments on commit cbe4143

Please sign in to comment.