-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathimage.py
73 lines (60 loc) · 2.86 KB
/
image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
""" Encoders and decoders specific to tasks that operate over images. """
import torch
import torchvision.transforms as transforms
from coders.coder import Encoder, Decoder
import util.util
class ConcatenationEncoder(Encoder):
    """
    Concatenates `k` images into a single parity image. This class is
    currently only defined for `k = 2` and `k = 4`. For example, given
    `k = 2` 32 x 32 (height x width) input images, this encoder downsamples
    each image to be 32 x 16 pixels in size, and then concatenates the two
    downsampled images side-by-side horizontally. Given `k = 4` 32 x 32
    images, each image is downsampled to be 16 x 16 pixels in size and
    placed in quadrants of a resultant parity image.
    """

    def __init__(self, ec_k, ec_r, in_dim):
        """
        Parameters
        ----------
        ec_k : int
            Number of source images combined into each parity image.
            Must be 2 or 4.
        ec_r : int
            Number of parity images (passed through to the base ``Encoder``).
        in_dim : sequence of int
            Input dimensions; ``in_dim[2]`` is image height and
            ``in_dim[3]`` is image width.

        Raises
        ------
        ValueError
            If ``ec_k`` is not 2 or 4, or if image height or width is odd.
        """
        super().__init__(ec_k, ec_r, in_dim)
        if ec_k not in (2, 4):
            raise ValueError(
                "ConcatenationEncoder currently supports values of `ec_k` of 2 or 4.")

        self.original_height = self.in_dim[2]
        self.original_width = self.in_dim[3]
        if (self.original_height % 2 != 0) or (self.original_width % 2 != 0):
            raise ValueError(
                "ConcatenationEncoder requires that image height and "
                "width be divisible by 2. Image received with shape: "
                + str(self.in_dim))

        # For k = 2, halve only the width (side-by-side layout); for k = 4,
        # halve both dimensions (2 x 2 quadrant layout). Either way the k
        # resized images tile the original H x W canvas exactly.
        if ec_k == 2:
            self.resized_height = self.original_height
            self.resized_width = self.original_width // 2
        else:
            # `ec_k` = 4
            self.resized_height = self.original_height // 2
            self.resized_width = self.original_width // 2

    def forward(self, in_data):
        """
        Tile groups of `ec_k` (already-resized) images into parity images.

        Parameters
        ----------
        in_data : ``torch.Tensor``
            Batch of input images. Assumed to already be resized to
            (resized_height, resized_width) — see ``resize_transform`` —
            and to contain ``ec_k`` consecutive images per output parity.

        Returns
        -------
        ``torch.Tensor`` of shape (batch_size, 1, original_height,
        original_width) holding one single-channel parity image per group.
        """
        batch_size = in_data.size(0)

        # Start from an all-zero canvas at the original resolution; each
        # resized input is then pasted into its half/quadrant below.
        out = util.util.try_cuda(
            torch.zeros(batch_size, 1,
                        self.original_height, self.original_width))

        # Group the flat batch so that axis 1 indexes the `ec_k` images
        # that make up one parity image.
        reshaped = in_data.view(-1, self.ec_k,
                                self.resized_height, self.resized_width)

        if self.ec_k == 2:
            # Side-by-side: image 0 on the left, image 1 on the right.
            out[:, :, :, :self.resized_width] = reshaped[:, 0].unsqueeze(1)
            out[:, :, :, self.resized_width:] = reshaped[:, 1].unsqueeze(1)
        else:
            # `ec_k` = 4: quadrants ordered top-left, top-right,
            # bottom-left, bottom-right.
            out[:, :, :self.resized_height, :self.resized_width] = reshaped[:, 0].unsqueeze(1)
            out[:, :, :self.resized_height, self.resized_width:] = reshaped[:, 1].unsqueeze(1)
            out[:, :, self.resized_height:, :self.resized_width] = reshaped[:, 2].unsqueeze(1)
            out[:, :, self.resized_height:, self.resized_width:] = reshaped[:, 3].unsqueeze(1)
        return out

    def resize_transform(self):
        """
        Returns
        -------
        A transform that resizes images to be the size needed for
        concatenation.
        """
        return transforms.Resize((self.resized_height, self.resized_width))