"""
Normalized Cross-Correlation for pattern matching.
pytorch implementation
CVLab EPFL 2019
"""
import logging
import torch
from torch.nn import functional as F
ncc_logger = logging.getLogger(__name__)


def patch_mean(images, patch_shape):
    """
    Computes the local mean of an image or set of images.

    Args:
        images (Tensor): Expected size is (n_images, n_channels, *image_size). 1d, 2d, and 3d images are accepted.
        patch_shape (tuple): shape of the patch tensor (n_channels, *patch_size)

    Returns:
        Tensor same size as the image, with local means computed independently for each channel.

    Example::
        >>> images = torch.randn(4, 3, 15, 15)  # 4 images, 3 channels, 15x15 pixels each
        >>> patch_shape = 3, 5, 5  # 3 channels, 5x5 pixels neighborhood
        >>> means = patch_mean(images, patch_shape)
        >>> expected_mean = images[3, 2, :5, :5].mean()  # mean of the third image, channel 2, top left 5x5 patch
        >>> computed_mean = means[3, 2, 5//2, 5//2]  # computed mean whose 5x5 neighborhood covers same patch
        >>> computed_mean.isclose(expected_mean).item()
        1
    """
    channels, *patch_size = patch_shape
    dimensions = len(patch_size)
    padding = tuple(side // 2 for side in patch_size)

    conv = (F.conv1d, F.conv2d, F.conv3d)[dimensions - 1]

    # Convolution with these weights will effectively compute the channel-wise means
    patch_elements = torch.Tensor(patch_size).prod().item()
    weights = torch.full((channels, channels, *patch_size), fill_value=1 / patch_elements)
    weights = weights.to(images.device)

    # Make convolution operate on single channels
    channel_selector = torch.eye(channels).bool()
    weights[~channel_selector] = 0

    result = conv(images, weights, padding=padding, bias=None)
    return result
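
# Illustrative note (not in the original file): because the off-diagonal kernel
# entries are zeroed out, patch_mean behaves like per-channel average pooling
# with stride 1. For 2d inputs and patch_shape = (3, 5, 5), for example,
#   F.avg_pool2d(images, kernel_size=5, stride=1, padding=2)
# should match patch_mean(images, (3, 5, 5)) up to floating-point error.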


def patch_std(image, patch_shape):
    """
    Computes the local standard deviations of an image or set of images.

    Args:
        image (Tensor): Expected size is (n_images, n_channels, *image_size). 1d, 2d, and 3d images are accepted.
        patch_shape (tuple): shape of the patch tensor (n_channels, *patch_size)

    Returns:
        Tensor same size as the image, with local standard deviations computed independently for each channel.

    Example::
        >>> images = torch.randn(4, 3, 15, 15)  # 4 images, 3 channels, 15x15 pixels each
        >>> patch_shape = 3, 5, 5  # 3 channels, 5x5 pixels neighborhood
        >>> stds = patch_std(images, patch_shape)
        >>> patch = images[3, 2, :5, :5]
        >>> expected_std = patch.std(unbiased=False)  # standard deviation of the third image, channel 2, top left 5x5 patch
        >>> computed_std = stds[3, 2, 5//2, 5//2]  # computed standard deviation whose 5x5 neighborhood covers same patch
        >>> computed_std.isclose(expected_std).item()
        1
    """
    return (patch_mean(image**2, patch_shape) - patch_mean(image, patch_shape)**2).sqrt()


def channel_normalize(template):
    """
    Z-normalize image channels independently.
    """
    reshaped_template = template.clone().view(template.shape[0], -1)
    reshaped_template.sub_(reshaped_template.mean(dim=-1, keepdim=True))
    reshaped_template.div_(reshaped_template.std(dim=-1, keepdim=True, unbiased=False))

    return reshaped_template.view_as(template)
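
# Illustrative example (not in the original file): after channel_normalize,
# every channel has zero mean and unit (population) standard deviation:
#   t = torch.randn(3, 5, 5)
#   z = channel_normalize(t)
#   z.view(3, -1).mean(dim=-1)                 # ~0 for each channel
#   z.view(3, -1).std(dim=-1, unbiased=False)  # ~1 for each channel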


class NCC(torch.nn.Module):
    """
    Computes the [Zero-Normalized Cross-Correlation][1] between an image and a template.

    Example:
        >>> lena_path = "https://upload.wikimedia.org/wikipedia/en/7/7d/Lenna_%28test_image%29.png"
        >>> lena_tensor = torch.Tensor(plt.imread(lena_path)).permute(2, 0, 1).cuda()
        >>> patch_center = 275, 275
        >>> y1, y2 = patch_center[0] - 25, patch_center[0] + 25
        >>> x1, x2 = patch_center[1] - 25, patch_center[1] + 25
        >>> lena_patch = lena_tensor[:, y1:y2 + 1, x1:x2 + 1]
        >>> ncc = NCC(lena_patch)
        >>> ncc_response = ncc(lena_tensor[None, ...])
        >>> ncc_response.max()
        tensor(1.0000, device='cuda:0')
        >>> np.unravel_index(ncc_response.argmax(), lena_tensor.shape)
        (0, 275, 275)

    [1]: https://en.wikipedia.org/wiki/Cross-correlation#Zero-normalized_cross-correlation_(ZNCC)
    """

    def __init__(self, template, keep_channels=False):
        super().__init__()

        self.keep_channels = keep_channels

        channels, *template_shape = template.shape
        dimensions = len(template_shape)
        self.padding = tuple(side // 2 for side in template_shape)

        self.conv_f = (F.conv1d, F.conv2d, F.conv3d)[dimensions - 1]
        self.normalized_template = channel_normalize(template)
        ones = template.dim() * (1, )
        self.normalized_template = self.normalized_template.repeat(channels, *ones)
        # Make convolution operate on single channels
        channel_selector = torch.eye(channels).bool()
        self.normalized_template[~channel_selector] = 0
        # Reweight so that output is averaged
        patch_elements = torch.Tensor(template_shape).prod().item()
        self.normalized_template.div_(patch_elements)

    def forward(self, image):
        result = self.conv_f(image, self.normalized_template, padding=self.padding, bias=None)
        std = patch_std(image, self.normalized_template.shape[1:])
        result.div_(std)
        if not self.keep_channels:
            result = result.mean(dim=1)

        return result
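

if __name__ == "__main__":
    # Minimal self-contained usage sketch, added for illustration (not part of
    # the original file). It crops a patch out of a random image and checks
    # that the NCC response peaks at the patch centre; the 11x11 template size
    # and image shape below are arbitrary choices.
    torch.manual_seed(0)

    image = torch.randn(1, 3, 64, 64)                   # (n_images, n_channels, H, W)
    template = image[0, :, 20:31, 20:31].contiguous()   # 3-channel 11x11 patch centred at (25, 25)

    ncc = NCC(template)
    response = ncc(image)                               # (n_images, H, W) after averaging channels

    peak_flat = response.view(response.shape[0], -1).argmax(dim=-1)
    peak_y, peak_x = divmod(peak_flat.item(), image.shape[-1])
    print(f"NCC peak {response.max().item():.4f} at ({peak_y}, {peak_x})")  # expected ~1.0 at (25, 25)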