-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmean.py
42 lines (35 loc) · 1.17 KB
/
mean.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from __future__ import with_statement
import numpy as np
import os
import glob
import cv2
from PIL import Image
import torch
import torch.nn as nn
import torchvision
import glob
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import transforms
def mean__std(data_loader):
    """Compute the per-channel mean and population std of a dataset.

    Makes a single pass over ``data_loader``, accumulating per-channel sums
    of ``x`` and ``x ** 2``, then returns ``mean = E[x]`` and
    ``std = sqrt(E[x^2] - E[x]^2)`` (population / biased estimate).

    Args:
        data_loader: Iterable yielding either ``(data, label)`` pairs (as a
            ``DataLoader`` over an ``ImageFolder`` does) or bare ``data``
            tensors. ``data`` must be laid out as ``(N, C, ...)``; every
            dimension except the channel dimension 1 is reduced over.

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors of length ``C``. They use the
        input's floating dtype (``float32`` for non-floating input) and the
        input's device.

    Raises:
        ValueError: If ``data_loader`` yields no elements.
    """
    total_sum = None  # running per-channel sum of x
    total_sq = None   # running per-channel sum of x ** 2
    cnt = 0           # number of values seen per channel
    out_dtype = torch.float32
    for batch in data_loader:
        # Accept (data, label) pairs as well as bare tensors; labels are unused.
        data = batch[0] if isinstance(batch, (list, tuple)) else batch
        if data.is_floating_point():
            out_dtype = data.dtype
        # Accumulate in float64: summing many pixels in float32 loses
        # precision, and E[x^2] - E[x]^2 is prone to cancellation.
        data = data.double()
        reduce_dims = [0] + list(range(2, data.dim()))
        batch_sum = data.sum(dim=reduce_dims)
        batch_sq = (data * data).sum(dim=reduce_dims)
        if total_sum is None:
            # Initialise lazily so the accumulators match the data's channel
            # count and device (torch.empty would hold uninitialised memory).
            total_sum, total_sq = batch_sum, batch_sq
        else:
            total_sum = total_sum + batch_sum
            total_sq = total_sq + batch_sq
        cnt += data.numel() // data.size(1)
    if total_sum is None or cnt == 0:
        raise ValueError("mean__std: data_loader yielded no data")
    mean = total_sum / cnt
    # Rounding can push the variance slightly below zero; clamp so sqrt never
    # returns NaN for (near-)constant channels.
    var = (total_sq / cnt - mean * mean).clamp(min=0)
    return mean.to(out_dtype), var.sqrt().to(out_dtype)
if __name__ == '__main__':
    # Guard required: with num_workers > 0, DataLoader worker processes
    # started via the "spawn" method (default on Windows and macOS)
    # re-import this module, and unguarded top-level code would re-run and
    # fail there. It also keeps importing this file free of side effects.
    import sys

    # Dataset root may be passed on the command line; defaults to the
    # original hard-coded 'b' directory.
    root = sys.argv[1] if len(sys.argv) > 1 else 'b'
    train_data = torchvision.datasets.ImageFolder(root, transform=transforms.Compose([transforms.ToTensor()]))
    # Order does not affect the statistics, so no shuffling is needed.
    data_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=False, num_workers=4)
    mean, std = mean__std(data_loader)
    print(mean, std)