-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
124 lines (101 loc) · 4.02 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: [email protected]
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os
import numpy as np
import torch
from PIL import Image
from torch.autograd import Variable
def tensor_load_rgbimage(filename, size=None, scale=None, keep_asp=False):
    """Load an image file as a float RGB tensor of shape (3, H, W).

    Pixel values are kept on the 0-255 scale of the decoded image (no
    normalization is applied).

    Args:
        filename: path of the image to read; it is converted to RGB.
        size: if given, the target width in pixels. The target height is
            also ``size`` unless ``keep_asp`` is true, in which case the
            height is scaled to preserve the original aspect ratio.
        scale: used only when ``size`` is None; both dimensions are divided
            by this factor (and truncated to int).
        keep_asp: preserve the aspect ratio when resizing to ``size``.

    Returns:
        A float32 torch tensor laid out channel-first as (3, H, W).
    """
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    # under its canonical name.
    resample = Image.LANCZOS
    # Close the underlying file handle promptly; convert() returns an
    # independent, fully loaded image.
    with Image.open(filename) as src:
        img = src.convert('RGB')
    if size is not None:
        if keep_asp:
            size2 = int(size * 1.0 / img.size[0] * img.size[1])
            img = img.resize((size, size2), resample)
        else:
            img = img.resize((size, size), resample)
    elif scale is not None:
        img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), resample)
    # (H, W, 3) uint8 array -> (3, H, W)
    img = np.array(img).transpose(2, 0, 1)
    return torch.from_numpy(img).float()
def tensor_save_rgbimage(tensor, filename, cuda=False):
    """Write a channel-first (3, H, W) RGB tensor to an image file.

    Values are clamped to [0, 255] and truncated to uint8 before saving. The
    input tensor is not modified.

    Args:
        tensor: (3, H, W) RGB image tensor, on any device.
        filename: destination path passed to ``PIL.Image.save``.
        cuda: kept for backward compatibility only. The tensor is always
            moved to the CPU, so the flag no longer has to match where the
            tensor lives.

    Returns:
        The ``PIL.Image`` that was written.
    """
    # detach() so tensors that require grad can be converted to NumPy;
    # cpu() is a no-op for CPU tensors, and clamp() is out-of-place, so no
    # explicit clone is needed to protect the caller's tensor.
    img = tensor.detach().cpu().clamp(0, 255).numpy()
    img = img.transpose(1, 2, 0).astype('uint8')
    img = Image.fromarray(img)
    img.save(filename)
    return img
def tensor_save_bgrimage(tensor, filename, cuda=False):
    """Write a channel-first (3, H, W) BGR tensor to an image file.

    The three channels are reordered to RGB and handed to
    ``tensor_save_rgbimage``.

    Args:
        tensor: (3, H, W) tensor with channels in B, G, R order.
        filename: destination path.
        cuda: forwarded unchanged to ``tensor_save_rgbimage``.

    Returns:
        Whatever ``tensor_save_rgbimage`` returns (the saved image), so the
        BGR and RGB savers behave consistently.
    """
    (b, g, r) = torch.chunk(tensor, 3)
    tensor = torch.cat((r, g, b))
    return tensor_save_rgbimage(tensor, filename, cuda)
def gram_matrix(y):
    """Compute the per-sample Gram matrix of a batch of feature maps.

    Args:
        y: tensor of shape (b, ch, h, w).

    Returns:
        Tensor of shape (b, ch, ch) equal to F @ F^T / (ch * h * w), where
        F is ``y`` flattened to (b, ch, h * w).
    """
    (b, ch, h, w) = y.size()
    # reshape rather than view: view raises on non-contiguous inputs
    # (e.g. transposed or sliced feature maps), reshape copies only if needed.
    features = y.reshape(b, ch, h * w)
    features_t = features.transpose(1, 2)
    return features.bmm(features_t) / (ch * h * w)
def subtract_imagenet_mean_batch(batch):
    """Subtract the ImageNet mean pixel from a batch of BGR images.

    Args:
        batch: tensor of shape (N, 3, H, W) with channels in B, G, R order.
            Presumably pixel values are on the 0-255 scale, matching the
            mean constants below; confirm against callers.

    Returns:
        A new tensor ``batch - mean`` where the per-channel mean is
        (103.939, 116.779, 123.680) for (B, G, R). Autograd is preserved.
    """
    # Build the mean on the batch's device and let broadcasting expand it.
    # On recent PyTorch, the legacy ``type(batch.data)(size)`` constructor
    # allocated a CPU tensor (device mismatch for CUDA batches), and it left
    # any channel past the third uninitialized. The default float dtype is
    # kept so type promotion matches the old behavior (e.g. uint8 input
    # still yields a float result).
    mean = torch.tensor([103.939, 116.779, 123.680], device=batch.device).view(1, 3, 1, 1)
    return batch - mean
def add_imagenet_mean_batch(batch):
    """Add the ImageNet mean pixel to a batch of BGR images.

    This is the inverse of subtracting the per-channel ImageNet mean.

    Args:
        batch: tensor of shape (N, 3, H, W) with channels in B, G, R order.
            Presumably mean-subtracted values on the 0-255 scale; confirm
            against callers.

    Returns:
        A new tensor ``batch + mean`` where the per-channel mean is
        (103.939, 116.779, 123.680) for (B, G, R). Autograd is preserved.
    """
    # Build the mean on the batch's device and let broadcasting expand it.
    # On recent PyTorch, the legacy ``type(batch.data)(size)`` constructor
    # allocated a CPU tensor (device mismatch for CUDA batches), and it left
    # any channel past the third uninitialized. The default float dtype is
    # kept so type promotion matches the old behavior.
    mean = torch.tensor([103.939, 116.779, 123.680], device=batch.device).view(1, 3, 1, 1)
    return batch + mean
def imagenet_clamp_batch(batch, low, high):
    """Clamp a mean-subtracted BGR batch in place.

    Channel ``c`` is clamped to ``[low - mean_c, high - mean_c]`` so that,
    once the ImageNet mean is added back, every pixel lies in [low, high].

    Args:
        batch: tensor indexed as (N, C, H, W); channels 0, 1, 2 are treated
            as B, G, R and are modified in place.
        low: lower bound in the un-shifted pixel scale.
        high: upper bound in the un-shifted pixel scale.

    Returns:
        None; ``batch`` is updated in place.
    """
    bgr_means = (103.939, 116.779, 123.680)
    for channel, mean in enumerate(bgr_means):
        # Work on .data so the in-place clamp bypasses autograd tracking.
        batch[:, channel, :, :].data.clamp_(low - mean, high - mean)
def preprocess_batch(batch):
    """Swap the channel order of a batch from RGB to BGR.

    The tensor is split into three equal chunks along dimension 1, and the
    chunks are concatenated back in reverse order.

    Args:
        batch: tensor whose dimension 1 holds the channels, e.g. (N, 3, H, W).

    Returns:
        A new tensor with the same shape and the channel chunks reversed.
    """
    red, green, blue = torch.chunk(batch, 3, dim=1)
    return torch.cat((blue, green, red), dim=1)
def init_vgg16(model_folder):
    """Ensure ``vgg16.weight`` exists in ``model_folder``.

    If the converted weights are missing:

    1. Download the Torch7 checkpoint ``vgg16.t7``, unless it is already
       present.
    2. Load it with ``load_lua``.
    3. Copy its parameters, in order, into a fresh ``Vgg16`` module.
    4. Save that module's ``state_dict`` as ``vgg16.weight``.

    Args:
        model_folder: directory holding the checkpoint files. It is created
            if it does not exist.

    NOTE(review): ``load_lua`` and ``Vgg16`` are not defined in this file's
    visible imports; confirm they are provided elsewhere. ``load_lua``
    (torch.utils.serialization) is not available in modern PyTorch releases.
    """
    import urllib.request

    weight_path = os.path.join(model_folder, 'vgg16.weight')
    if os.path.exists(weight_path):
        return
    t7_path = os.path.join(model_folder, 'vgg16.t7')
    if not os.path.exists(t7_path):
        os.makedirs(model_folder, exist_ok=True)
        # Download without a shell: the old os.system('wget ... -O ' + path)
        # broke on paths containing spaces or shell metacharacters and
        # ignored failures. Write to a temporary name and rename on success,
        # so an interrupted download never leaves a truncated vgg16.t7 that
        # later runs would treat as complete.
        partial_path = t7_path + '.part'
        try:
            urllib.request.urlretrieve(
                'http://cs.stanford.edu/people/jcjohns/fast-neural-style/models/vgg16.t7',
                partial_path)
            os.replace(partial_path, t7_path)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)
    vgglua = load_lua(t7_path)
    vgg = Vgg16()
    for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()):
        dst.data[:] = src
    torch.save(vgg.state_dict(), weight_path)
class StyleLoader():
    """Serve style images from a folder by (wrapping) integer index.

    Args:
        style_folder: directory whose regular files are loaded as style
            images.
        style_size: ``size`` argument passed to ``tensor_load_rgbimage``.
        cuda: if true, loaded style tensors are moved to the GPU.
    """

    def __init__(self, style_folder, style_size, cuda=True):
        self.folder = style_folder
        self.style_size = style_size
        # os.listdir order is arbitrary and platform-dependent. Sort it so a
        # given index always maps to the same style image, and skip
        # sub-directories, which cannot be opened as images.
        self.files = sorted(
            name for name in os.listdir(style_folder)
            if os.path.isfile(os.path.join(style_folder, name)))
        self.cuda = cuda

    def get(self, i):
        """Return style image ``i % size()`` as a (1, 3, H, W) BGR Variable.

        Raises:
            ValueError: if the folder contains no style images.
        """
        if not self.files:
            # Previously this surfaced as an opaque ZeroDivisionError.
            raise ValueError('no style images found in %r' % (self.folder,))
        idx = i % len(self.files)
        filepath = os.path.join(self.folder, self.files[idx])
        style = tensor_load_rgbimage(filepath, self.style_size)
        style = style.unsqueeze(0)  # (3, H, W) -> (1, 3, H, W)
        style = preprocess_batch(style)  # RGB -> BGR
        if self.cuda:
            style = style.cuda()
        style_v = Variable(style, requires_grad=False)
        return style_v

    def size(self):
        """Return the number of style images available."""
        return len(self.files)