-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
66 lines (53 loc) · 2.14 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from torch.utils.data import Dataset
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from PIL import Image
import numpy as np
from torchbearer import deep_to
import torchbearer
def custom_loader(state):
    """Load the next batch and append the PHOW feature to the image as one extra channel.

    Reads the next ``(img, phow_feature, label)`` tuple from
    ``state[torchbearer.ITERATOR]`` and moves it with ``deep_to`` to
    ``state[torchbearer.DEVICE]`` / ``state[torchbearer.DATA_TYPE]``.
    The per-sample PHOW feature is reshaped to ``(batch_size, 1, 224, 224)``
    and concatenated with ``img`` along dim 1. The result is stored in
    ``state[torchbearer.X]`` and ``label`` in ``state[torchbearer.Y_TRUE]``.

    NOTE(review): assumes ``img`` is (batch, C, 224, 224) and every PHOW
    feature holds exactly 224 * 224 elements — TODO confirm against the
    feature extractor.

    Raises:
        RuntimeError: if the PHOW batch does not contain exactly
            ``batch_size * 224 * 224`` elements, or if its shape is otherwise
            incompatible with ``img`` for concatenation.
    """
    img, phow_feature, label = deep_to(
        next(state[torchbearer.ITERATOR]),
        state[torchbearer.DEVICE],
        state[torchbearer.DATA_TYPE],
    )
    batch_size = img.shape[0]
    # Use reshape rather than the in-place ``resize_``. On a size mismatch,
    # resize_ silently truncates or exposes uninitialized memory; reshape
    # raises instead. reshape also leaves the collated batch tensor unmutated.
    phow_channel = phow_feature.reshape(batch_size, 1, 224, 224)
    state[torchbearer.X] = torch.cat([img, phow_channel], 1)
    state[torchbearer.Y_TRUE] = label
class CourseworkDataset(Dataset):
    """Image + PHOW-feature dataset backed by CSV file lists.

    Each item is ``(image, phow_feature, label)``:

    * ``image`` is ``transform(image=<RGB ndarray>)['image']``;
    * ``phow_feature`` is ``torch.tensor(path2phow_features[path][0])``;
    * ``label`` is ``torch.tensor(label)``.

    Modes:
        ``"train"`` / ``"test"``: read ``train_csv`` (columns ``path`` and
        ``label``) and split it with ``sklearn``'s ``train_test_split``
        using ``test_size`` and ``random_state``. ``"train"`` keeps the
        larger part and ``"test"`` the held-out part. The fixed
        ``random_state`` makes the two modes see complementary rows.

        ``"predict"``: read ``predict_csv`` (columns ``path`` and
        ``predict_label``) without splitting.
    """

    _MODES = ("train", "test", "predict")

    def __init__(
        self,
        transform,
        mode,
        path2phow_features,
        *,
        train_csv="train_list.csv",
        predict_csv="test_list.csv",
        test_size=0.33,
        random_state=42,
    ):
        """Build the image/label lists for ``mode``.

        Args:
            transform: Callable invoked as ``transform(image=ndarray)`` that
                returns a mapping with an ``'image'`` entry.
                NOTE(review): looks albumentations-style — confirm.
            mode: One of ``"train"``, ``"test"`` or ``"predict"``.
            path2phow_features: Mapping from image path to a sequence whose
                first element is that image's PHOW feature.
            train_csv: CSV used for the ``"train"`` / ``"test"`` modes.
            predict_csv: CSV used for the ``"predict"`` mode.
            test_size: Fraction of ``train_csv`` held out for ``"test"``.
            random_state: Seed for the train/test split.

        Raises:
            ValueError: if ``mode`` is not a supported mode. The check runs
                before any CSV is read.
        """
        super().__init__()
        if mode not in self._MODES:
            raise ValueError(
                f"mode must be one of {self._MODES}, got {mode!r}"
            )
        self.transform = transform
        self.mode = mode
        self.path2phow_features = path2phow_features

        if mode == "predict":
            predict_list = pd.read_csv(predict_csv)
            self.images = predict_list["path"].tolist()
            self.labels = predict_list["predict_label"].tolist()
            return

        train_list = pd.read_csv(train_csv)
        train, test = train_test_split(
            train_list, test_size=test_size, random_state=random_state
        )
        subset = train if mode == "train" else test
        self.images = subset["path"].tolist()
        self.labels = subset["label"].tolist()

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, phow_feature, label)`` for sample ``idx``."""
        path, label = self.images[idx], self.labels[idx]
        # Only the first entry stored for this path is used as the feature.
        phow_feature = torch.tensor(self.path2phow_features[path][0])
        # The context manager closes the file handle deterministically.
        # convert() returns a fully loaded copy, so the array stays valid.
        with Image.open(path) as pil_image:
            image_np = np.array(pil_image.convert("RGB"))
        # The transform returns a dict; its 'image' entry is the model input.
        image = self.transform(image=image_np)["image"]
        return image, phow_feature, torch.tensor(label)