forked from netabecker/Stegastamp_pytorch_version
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdataset.py
50 lines (41 loc) · 1.82 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import numpy as np
from glob import glob
from PIL import Image, ImageOps
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
class StegaData(Dataset):
def __init__(self, data_path, secret_size=100, size=(400, 400), subset_size=None):
self.data_path = data_path
self.secret_size = secret_size
self.size = size
self.files_list = glob(f'{self.data_path}/images*/images/*/*.jpg')
if subset_size:
self.files_list = self.files_list[:subset_size]
self.to_tensor = transforms.ToTensor()
def __getitem__(self, idx):
img_cover_path = self.files_list[idx]
img_cover = Image.open(img_cover_path).convert('RGB')
img_cover = ImageOps.fit(img_cover, self.size)
img_cover = self.to_tensor(img_cover)
# to_tensor already normalizes in between [0,1]
# img_cover = np.array(img_cover, dtype=np.float32) / 255.
secret = np.random.binomial(1, 0.5, self.secret_size)
secret = torch.from_numpy(secret).float()
return img_cover, secret
def __len__(self):
return len(self.files_list)
# if __name__ == '__main__':
# # dataset = StegaData(data_path='F:\\VOCdevkit\\VOC2012\\JPEGImages')
# # print(len(dataset))
# # img_cover, secret = dataset[10]
# # print(type(img_cover), type(secret))
# # print(img_cover.shape, secret.shape)
#
# dataset = StegaData(data_path=r'E:\dataset\mirflickr', secret_size=100, size=(400, 400))
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, pin_memory=True)
# image_input, secret_input = next(iter(dataloader))
# print(type(image_input), type(secret_input))
# print(image_input.shape, secret_input.shape)
# print(image_input.max())