Commit 3275339

Update

MaitreChen committed Apr 29, 2023
1 parent d52e954 commit 3275339

Showing 3 changed files with 31 additions and 32 deletions.
4 changes: 4 additions & 0 deletions configs/config.yaml
@@ -13,3 +13,7 @@ root2: './data/fake'
 # model input size
 image_size: 224
 
+# properties
+mean: [0.00924097]
+std: [0.00282327]

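For context, these single-channel statistics feed transforms.Normalize in utils/classifier_dataset.py below. A minimal sketch of how such values are commonly computed over a training folder; the directory path and helper name are illustrative, not part of this commit:

    import os

    import numpy as np
    from PIL import Image

    def compute_mean_std(img_dir):
        # Accumulate pixel statistics with values scaled to [0, 1],
        # matching what transforms.ToTensor() produces.
        total, total_sq, n = 0.0, 0.0, 0
        for name in os.listdir(img_dir):
            img = np.asarray(Image.open(os.path.join(img_dir, name)).convert('L'),
                             dtype=np.float64) / 255.0
            total += img.sum()
            total_sq += (img ** 2).sum()
            n += img.size
        mean = total / n
        std = (total_sq / n - mean ** 2) ** 0.5
        return [mean], [std]  # list form, as config.yaml stores them

    # Hypothetical call; the path mirrors the repo's data layout.
    print(compute_mean_std('./data/real/train/normal'))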
53 changes: 24 additions & 29 deletions utils/classifier_dataset.py
@@ -1,17 +1,17 @@
 # IMPORT PACKAGES
 from PIL import Image
+import yaml
 import os
 
 from torchvision import transforms
 from torch.utils import data
 
-img_size = 224
-
-data_transform = transforms.Compose([
-    transforms.ToTensor(),
-    transforms.Resize([img_size, img_size]),
-    transforms.Normalize(mean=[.5], std=[.5])
-])
+# load config
+with open('./configs/config.yaml', 'r', encoding='utf-8') as f:
+    yaml_info = yaml.load(f.read(), Loader=yaml.FullLoader)
+IMAGE_SIZE = yaml_info['image_size']
+MEAN = yaml_info['mean']
+STD = yaml_info['std']
 
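A side note on the loading call above: for a flat config of scalars and lists like this one, yaml.safe_load is the usual, safer equivalent of yaml.load with FullLoader (a sketch, not part of the commit):

    with open('./configs/config.yaml', 'r', encoding='utf-8') as f:
        yaml_info = yaml.safe_load(f)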
# resize 64x64 images
@@ -27,8 +27,8 @@ class ImageTransform():
     def __init__(self):
         self.data_transform = transforms.Compose([
             transforms.ToTensor(),
-            transforms.Resize([img_size, img_size]),
-            transforms.Normalize(mean=[0.00924097], std=[0.00282327])
+            transforms.Resize([IMAGE_SIZE, IMAGE_SIZE]),
+            transforms.Normalize(mean=MEAN, std=STD)
         ])
 
     def __call__(self, img):
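For context, a hypothetical usage sketch of the class above, assuming its collapsed __call__ body simply applies self.data_transform; the image path is a placeholder:

    from PIL import Image

    transform = ImageTransform()
    img = Image.open('sample_xray.jpeg').convert('L')  # placeholder file
    x = transform(img)  # float tensor of shape [1, IMAGE_SIZE, IMAGE_SIZE]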
@@ -64,37 +64,28 @@ def __init__(self, root1, root2, mode='train', transform=None):
         # dataPath2 = {'train': self.train_path2, 'val': self.val_path2, 'test': self.test_path2}
 
         # class1: normal images and label
-        # real images
+        # real images (normal)
         normal_path1 = os.path.join(dataPath1[self.mode], 'normal')
 
-        # modify the ratio of the test images
-        if self.mode == 'test':
-            normal_img_list1 = os.listdir(normal_path1)
-        else:
-            normal_img_list1 = os.listdir(normal_path1)
+        normal_img_list1 = os.listdir(normal_path1)
         normal_img_list1_ = [os.path.join(normal_path1, _) for _ in normal_img_list1]
         normal_label_list1_ = [0] * len(normal_img_list1)
 
-        # normal_img_list_ = normal_img_list1_
-        # normal_label_list_ = normal_label_list1_
-
-        # fake images
+        # fake images (normal)
         normal_path2 = os.path.join(self.root2, 'normal', self.mode)
         normal_img_list2 = os.listdir(normal_path2)
 
         normal_img_list2_ = [os.path.join(normal_path2, _) for _ in normal_img_list2]
         normal_label_list2_ = [0] * len(normal_img_list2)
 
-        # add
+        # real + fake (normal)
         normal_img_list_ = normal_img_list1_ + normal_img_list2_
         normal_label_list_ = normal_label_list1_ + normal_label_list2_
 
         # class2: abnormal images and label
         pneumonia_path = os.path.join(dataPath1[self.mode], 'pneumonia')
-        if self.mode == 'test':
-            pneumonia_img_list = os.listdir(pneumonia_path)[:376]
-        else:
-            pneumonia_img_list = os.listdir(pneumonia_path)
+        pneumonia_img_list = os.listdir(pneumonia_path)
         pneumonia_img_list_ = [os.path.join(pneumonia_path, _) for _ in pneumonia_img_list]
         pneumonia_label_list_ = [1] * len(pneumonia_img_list)

@@ -123,13 +114,17 @@ def __getitem__(self, index):


 if __name__ == "__main__":
-    train_dataset = PneumoniaDataset(root1='../data/real', root2='../data/fake',
+    train_dataset = PneumoniaDataset(root1='..\\data\\real', root2='..\\data\\fake',
                                      mode='train', transform=ImageTransform())
 
     print(len(train_dataset))
-    val_loader = data.DataLoader(
-        train_dataset, batch_size=800, shuffle=True)
-
-    batch_iter = iter(val_loader)
-    data = next(batch_iter)
-    print(data[1])
+    val_dataset = PneumoniaDataset(root1='..\\data\\real', root2='..\\data\\fake',
+                                   mode='val', transform=ImageTransform())
+
+    print(len(val_dataset))
+
+    test_dataset = PneumoniaDataset(root1='..\\data\\real', root2='..\\data\\fake',
+                                    mode='test', transform=ImageTransform())
+
+    print(len(test_dataset))
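With the DataLoader smoke test replaced by size checks, a minimal sketch of how these datasets would typically be consumed, assuming each item is an (image, label) pair as the removed test implied; the batch size and worker count are illustrative:

    from torch.utils import data  # already imported at the top of this file

    train_loader = data.DataLoader(train_dataset, batch_size=32,
                                   shuffle=True, num_workers=2)
    images, labels = next(iter(train_loader))
    print(images.shape, labels[:8])  # e.g. torch.Size([32, 1, 224, 224])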
6 changes: 3 additions & 3 deletions utils/common_utils.py
@@ -23,7 +23,7 @@ def copy_files_to_output_dir(files, dataset_dir, output_dir, sub_dir):
         shutil.copy(src, dst)
 
 
-def split_dataset(dataset_dir: str, output_dir: str, train_percent: float, val_percent: float, test_percent: float):
+def split_dataset(dataset_dir, output_dir, train_percent, val_percent, test_percent):
     """
     Split the dataset into train, validation, and test sets and copy them to the output directory.
@@ -64,7 +64,7 @@ def split_dataset(dataset_dir: str, output_dir: str, train_percent: float, val_p


 if __name__ == '__main__':
-    input_dataset_dir = '../data/fake/normal1/'
-    output_dataset_dir = '../data/fake/normal'
+    input_dataset_dir = '..\\data\\fake\\normal'
+    output_dataset_dir = '..\\data\\fake\\normal'
 
     split_dataset(input_dataset_dir, output_dataset_dir, 0.8475, 0.1, 0.05)

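Two caveats on the block above: the ratios 0.8475 + 0.1 + 0.05 sum to 0.9975, so roughly 0.25% of files fall into no split unless split_dataset assigns the remainder, and the '..\\' strings only work on Windows. A portable sketch using os.path.join, not part of this commit:

    import os

    input_dataset_dir = os.path.join('..', 'data', 'fake', 'normal')
    output_dataset_dir = os.path.join('..', 'data', 'fake', 'normal')
    split_dataset(input_dataset_dir, output_dataset_dir, 0.8475, 0.1, 0.05)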