forked from AliaksandrSiarohin/cuda-gridsample-grad2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnaive_gridsample.py
168 lines (125 loc) · 5.99 KB
/
naive_gridsample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import torch
def grid_sample_3d(image, optical):
    """Trilinearly sample a batch of volumes at normalized grid locations.

    Built only from elementwise arithmetic and ``torch.gather``.  Normalized
    coordinates are mapped so that -1 and +1 land on the centres of the first
    and last voxel along each axis.  Locations outside ``[-1, 1]`` return the
    nearest border value: the out-of-range corner indices are clamped while
    the interpolation weights of each axis still sum to one.

    Args:
        image: Tensor of shape ``(N, C, ID, IH, IW)`` holding the volumes.
        optical: Tensor of shape ``(N, D, H, W, 3)`` whose last dimension
            stores normalized ``(x, y, z)`` coordinates, with ``x`` indexing
            ``IW``, ``y`` indexing ``IH`` and ``z`` indexing ``ID``.

    Returns:
        Tensor of shape ``(N, C, D, H, W)`` with the interpolated values.
    """
    N, C, ID, IH, IW = image.shape
    _, D, H, W, _ = optical.shape

    # Map normalized coordinates from [-1, 1] to voxel coordinates.
    ix = ((optical[..., 0] + 1) / 2) * (IW - 1)
    iy = ((optical[..., 1] + 1) / 2) * (IH - 1)
    iz = ((optical[..., 2] + 1) / 2) * (ID - 1)

    # Corner positions are piecewise constant in the coordinates, so they are
    # computed without tracking gradients; only the weights carry gradient.
    with torch.no_grad():
        ix_lo = torch.floor(ix)
        iy_lo = torch.floor(iy)
        iz_lo = torch.floor(iz)
        ix_hi = ix_lo + 1
        iy_hi = iy_lo + 1
        iz_hi = iz_lo + 1

        # Clamp while still in floating point (so infinities are handled),
        # then switch to int64 BEFORE building flat indices.  float32 cannot
        # represent every integer above 2**24, so flat indices computed in
        # the grid dtype would address the wrong voxel in large volumes.
        xs = (ix_lo.clamp(0, IW - 1).long(), ix_hi.clamp(0, IW - 1).long())
        ys = (iy_lo.clamp(0, IH - 1).long(), iy_hi.clamp(0, IH - 1).long())
        zs = (iz_lo.clamp(0, ID - 1).long(), iz_hi.clamp(0, ID - 1).long())

    # Per-axis weights from the UNCLAMPED corners: entry 0 belongs to the
    # lower corner (distance to the upper one), entry 1 to the upper corner.
    wx = (ix_hi - ix, ix - ix_lo)
    wy = (iy_hi - iy, iy - iy_lo)
    wz = (iz_hi - iz, iz - iz_lo)

    # reshape (rather than view) also accepts non-contiguous inputs.
    flat = image.reshape(N, C, ID * IH * IW)

    out_val = None
    # Accumulation order: tnw, tne, tsw, tse, bnw, bne, bsw, bse.
    for k in (0, 1):
        for j in (0, 1):
            for i in (0, 1):
                index = (zs[k] * IH + ys[j]) * IW + xs[i]
                # Broadcast the same spatial index over all channels
                # without materialising C copies of it.
                index = index.reshape(N, 1, D * H * W).expand(-1, C, -1)
                value = torch.gather(flat, 2, index).view(N, C, D, H, W)
                weight = (wx[i] * wy[j] * wz[k]).unsqueeze(1)
                term = value * weight
                out_val = term if out_val is None else out_val + term

    return out_val
def grid_sample_2d(image, optical):
    """Bilinearly sample a batch of images at normalized grid locations.

    Built only from elementwise arithmetic and ``torch.gather``.  Normalized
    coordinates are mapped so that -1 and +1 land on the centres of the first
    and last pixel along each axis.  Locations outside ``[-1, 1]`` return the
    nearest border value: the out-of-range corner indices are clamped while
    the interpolation weights of each axis still sum to one.

    Args:
        image: Tensor of shape ``(N, C, IH, IW)`` holding the images.
        optical: Tensor of shape ``(N, H, W, 2)`` whose last dimension stores
            normalized ``(x, y)`` coordinates, with ``x`` indexing ``IW`` and
            ``y`` indexing ``IH``.

    Returns:
        Tensor of shape ``(N, C, H, W)`` with the interpolated values.
    """
    N, C, IH, IW = image.shape
    _, H, W, _ = optical.shape

    # Map normalized coordinates from [-1, 1] to pixel coordinates.
    ix = ((optical[..., 0] + 1) / 2) * (IW - 1)
    iy = ((optical[..., 1] + 1) / 2) * (IH - 1)

    # Corner positions are piecewise constant in the coordinates, so they are
    # computed without tracking gradients; only the weights carry gradient.
    with torch.no_grad():
        ix_nw = torch.floor(ix)
        iy_nw = torch.floor(iy)
        ix_se = ix_nw + 1
        iy_se = iy_nw + 1

        # Clamp while still in floating point (so infinities are handled),
        # then switch to int64 BEFORE building flat indices.  float32 cannot
        # represent every integer above 2**24, so flat indices computed in
        # the grid dtype would address the wrong pixel in large images.
        x0 = ix_nw.clamp(0, IW - 1).long()
        x1 = ix_se.clamp(0, IW - 1).long()
        y0 = iy_nw.clamp(0, IH - 1).long()
        y1 = iy_se.clamp(0, IH - 1).long()

    # Each corner is weighted by the area of the rectangle spanned by the
    # sample point and the diagonally opposite (unclamped) corner.
    nw = (ix_se - ix) * (iy_se - iy)
    ne = (ix - ix_nw) * (iy_se - iy)
    sw = (ix_se - ix) * (iy - iy_nw)
    se = (ix - ix_nw) * (iy - iy_nw)

    # reshape (rather than view) also accepts non-contiguous inputs.
    flat = image.reshape(N, C, IH * IW)

    def corner_values(row, col):
        """Gather the pixels at integer (row, col) for every channel."""
        index = (row * IW + col).reshape(N, 1, H * W).expand(-1, C, -1)
        return torch.gather(flat, 2, index).view(N, C, H, W)

    out_val = (corner_values(y0, x0) * nw.unsqueeze(1) +
               corner_values(y0, x1) * ne.unsqueeze(1) +
               corner_values(y1, x0) * sw.unsqueeze(1) +
               corner_values(y1, x1) * se.unsqueeze(1))

    return out_val