-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_data.py
121 lines (88 loc) · 3.3 KB
/
load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import pandas as pd
import torch
import torchvision.io
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import constants
'''
class ToRGB():
def __call__(self, sample):
if sample.size(0) == 1:
sample = sample.repeat(3, 1, 1)
elif sample.size(0) >= 4:
sample = sample[:3]
return sample
'''
class ImageDataset(Dataset):
    """Image-classification dataset backed by one sub-directory per category.

    The directory ``img_dir`` is expected to contain one sub-directory for
    every entry of ``constants.TARGET_CATEGORIES``.  Each image file found in
    such a sub-directory becomes one sample whose integer label is the
    position of the category in ``constants.TARGET_CATEGORIES``.

    Args:
        img_dir: Root directory that holds the per-category sub-directories.
        transform: Callable applied to the RGB ``PIL.Image`` of every sample.
            Defaults to ``default_transform`` (to-tensor, resize, normalize).
            Pass a falsy value (e.g. ``None``) to get the raw PIL image.
        target_transform: Optional callable applied to the integer label.

    Raises:
        OSError: From ``os.listdir`` if a category sub-directory is missing
            or unreadable (e.g. ``FileNotFoundError``).
    """

    # Lower-case file extensions that are treated as images.
    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    default_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(constants.DEFAULT_IMG_SIZE, antialias=True),
        transforms.Normalize(mean=constants.MEAN, std=constants.STD),
    ])

    def __init__(self, img_dir, transform=default_transform, target_transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        # DataFrame with columns 'path' and 'label', one row per image.
        self.data = self.collect_images()

    def __len__(self):
        """Return the number of images that were collected."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as an ``(image, label)`` pair."""
        img_path = self.data.iloc[idx, 0]
        label = int(self.data.iloc[idx, 1])

        # Image.open() keeps the file handle open until the image is closed.
        # convert() loads the pixel data and returns an independent image, so
        # the handle can be released right away instead of leaking one
        # descriptor per sample until garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

    def collect_images(self):
        """Scan the category sub-directories and index every image file.

        Returns:
            A ``pandas.DataFrame`` with a ``'path'`` column (full file path)
            and a ``'label'`` column (index of the category in
            ``constants.TARGET_CATEGORIES``).
        """
        img_paths = []
        labels = []
        for i_category, category in enumerate(constants.TARGET_CATEGORIES):
            img_dir_path = os.path.join(self.img_dir, category)
            new_paths = [
                os.path.join(img_dir_path, name)
                # os.listdir() returns entries in arbitrary order; sorting
                # makes the index -> sample mapping reproducible.
                for name in sorted(os.listdir(img_dir_path))
                # Compare case-insensitively so files such as 'IMG_1.JPG'
                # are not silently dropped.
                if name.lower().endswith(self.IMAGE_EXTENSIONS)
            ]
            img_paths.extend(new_paths)
            labels.extend([i_category] * len(new_paths))
        return pd.DataFrame(data={'path': img_paths, 'label': labels})
def get_dataset():
    """Create the dataset for the configured location.

    Returns:
        An ``ImageDataset`` rooted at ``constants.DATASET_PATH`` that uses the
        class's default image transform and no target transform.
    """
    dataset_root = constants.DATASET_PATH
    return ImageDataset(dataset_root)
def compute_stats(dataset, batch_size=128):
    """Compute the per-channel mean and standard deviation of a dataset.

    The statistics are taken over the image tensors exactly as the dataset
    yields them, i.e. after any transform the dataset applies.  If that
    transform already normalizes the images, the result describes the
    normalized data, not the raw pixels.

    Two passes are made over the data (mean first, then squared deviations
    from that mean).  The number of pixels is counted from the batches
    themselves, so the result does not depend on any configured image size.

    Args:
        dataset: Dataset yielding ``(image, label)`` pairs where every image
            is a tensor of shape ``(channels, height, width)`` and all images
            share the same shape (required for batching).
        batch_size: Number of images loaded per step.  Only affects memory
            use and speed, not the result.

    Returns:
        A ``(mean, std)`` tuple of 1-D tensors with one entry per channel, in
        the default floating-point dtype.  ``std`` is the population standard
        deviation (divides by the pixel count, not by the count minus one).

    Raises:
        ValueError: If the dataset yields no pixels.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    reduce_dims = (0, 2, 3)  # everything except the channel axis

    # Pass 1: per-channel sum and pixel count.  Accumulate in float64 so that
    # adding many batches does not lose precision.
    channel_sum = None
    n_pixels = 0
    for batch, _ in loader:
        batch_sum = batch.to(torch.float64).sum(dim=reduce_dims)
        channel_sum = batch_sum if channel_sum is None else channel_sum + batch_sum
        n_pixels += batch.size(0) * batch.size(2) * batch.size(3)

    if n_pixels == 0:
        raise ValueError("cannot compute statistics of an empty dataset")

    mean = channel_sum / n_pixels

    # Pass 2: per-channel sum of squared deviations from the mean.
    centre = mean.view(1, -1, 1, 1)  # broadcast over batch, height, width
    sq_dev_sum = torch.zeros_like(mean)
    for batch, _ in loader:
        sq_dev_sum += (batch.to(torch.float64) - centre).pow(2).sum(dim=reduce_dims)

    std = torch.sqrt(sq_dev_sum / n_pixels)

    out_dtype = torch.get_default_dtype()
    return mean.to(out_dtype), std.to(out_dtype)
# Script section that computes the per-channel mean and standard deviation of the dataset
# (currently disabled: it is wrapped in the string literal below).
'''
if __name__ == '__main__':
ds = get_dataset()
print(compute_stats(ds))
loader = DataLoader(ds,
batch_size=32,
shuffle=False)
for data in loader:
print(data[0])
'''