forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathUpSampleBicubic2d.cpp
315 lines (271 loc) · 9.43 KB
/
UpSampleBicubic2d.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/UpSample.h>
namespace at {
namespace native {
namespace {
// Forward bicubic upsampling kernel over raw buffers.
//
// `idata` and `odata` are indexed as nbatch * channels consecutive planes of
// input_height x input_width and output_height x output_width elements
// respectively, each plane dense and row-major (assumes contiguous storage —
// the caller is expected to guarantee this).
//
// Each output pixel is mapped back to a fractional source coordinate. Four
// 1-D cubic interpolations along x (one per neighbouring input row) are then
// combined by a single cubic interpolation along y.
//
// Params:
//   odata          destination buffer (fully written for every plane)
//   idata          source buffer (only read)
//   input_height, input_width    size of each source plane
//   output_height, output_width  size of each destination plane
//   nbatch, channels             the kernel processes nbatch * channels planes
//   align_corners  forwarded to the scale / source-index helpers
template <typename scalar_t>
static void upsample_bicubic2d_out_frame(
    scalar_t* odata,
    scalar_t* idata,
    int64_t input_height,
    int64_t input_width,
    int64_t output_height,
    int64_t output_width,
    int64_t nbatch,
    int64_t channels,
    bool align_corners) {
  // Flatten (batch, channel) into one run of planes. Both the copy path and
  // the interpolation path below must visit all nbatch * channels planes.
  // Previously the copy path only covered `channels` planes, which left
  // every batch after the first untouched.
  channels = channels * nbatch;

  // Special case: input/output same size, just copy
  if (input_height == output_height && input_width == output_width) {
    for (int64_t output_y = 0; output_y < output_height; output_y++) {
      for (int64_t output_x = 0; output_x < output_width; output_x++) {
        const scalar_t* in = &idata[output_y * input_width + output_x];
        scalar_t* out = &odata[output_y * output_width + output_x];
        for (int64_t c = 0; c < channels; ++c) {
          out[0] = in[0];
          in += input_width * input_height;
          out += output_width * output_height;
        }
      }
    }
    return;
  }

  // Bicubic interpolation
  const scalar_t height_scale = area_pixel_compute_scale<scalar_t>(
      input_height, output_height, align_corners);
  const scalar_t width_scale = area_pixel_compute_scale<scalar_t>(
      input_width, output_width, align_corners);

  for (int64_t output_y = 0; output_y < output_height; output_y++) {
    for (int64_t output_x = 0; output_x < output_width; output_x++) {
      scalar_t* in = idata;
      scalar_t* out = odata;

      // Integer base coordinate plus fractional offset in [0, 1).
      // NOTE(review): floorf rounds a double real_x to float before
      // flooring; std::floor would keep full precision for scalar_t=double.
      const scalar_t real_x = area_pixel_compute_source_index(
          width_scale, output_x, align_corners, /*cubic=*/true);
      int64_t input_x = floorf(real_x);
      const scalar_t t_x = real_x - input_x;

      const scalar_t real_y = area_pixel_compute_source_index(
          height_scale, output_y, align_corners, /*cubic=*/true);
      int64_t input_y = floorf(real_y);
      const scalar_t t_y = real_y - input_y;

      for (int64_t c = 0; c < channels; c++) {
        scalar_t coefficients[4];

        // Interpolate 4 times in the x direction (rows input_y-1 .. input_y+2);
        // out-of-range taps are read through upsample_get_value_bounded.
        for (int64_t i = 0; i < 4; i++) {
          coefficients[i] = cubic_interp1d<scalar_t>(
              upsample_get_value_bounded<scalar_t>(
                  in, input_width, input_height, input_x - 1, input_y - 1 + i),
              upsample_get_value_bounded<scalar_t>(
                  in, input_width, input_height, input_x + 0, input_y - 1 + i),
              upsample_get_value_bounded<scalar_t>(
                  in, input_width, input_height, input_x + 1, input_y - 1 + i),
              upsample_get_value_bounded<scalar_t>(
                  in, input_width, input_height, input_x + 2, input_y - 1 + i),
              t_x);
        }

        // Interpolate in the y direction using x interpolations
        out[output_y * output_width + output_x] = cubic_interp1d<scalar_t>(
            coefficients[0],
            coefficients[1],
            coefficients[2],
            coefficients[3],
            t_y);

        // Move to next plane
        in += input_width * input_height;
        out += output_width * output_height;
      }
    }
  }
}
// Backward bicubic upsampling kernel over raw buffers.
//
// `odata` holds the incoming gradient (output_height x output_width per plane)
// and is only read. `idata` receives the gradient w.r.t. the input
// (input_height x input_width per plane). There are nbatch * channels planes
// in both, laid out back to back (assumes dense row-major planes — the
// caller is expected to pass contiguous buffers).
//
// Same-size case: the gradient is assigned straight through.
// Otherwise: every output-gradient value is spread over the 4x4 input
// neighbourhood used by the forward pass. Each tap is weighted by the product
// of its y and x cubic coefficients and handed to
// upsample_increment_value_bounded (presumably an add with index clamping —
// confirm in UpSample.h).
template <typename scalar_t>
static void upsample_bicubic2d_backward_out_frame(
    scalar_t* odata,
    scalar_t* idata,
    int64_t input_height,
    int64_t input_width,
    int64_t output_height,
    int64_t output_width,
    int64_t nbatch,
    int64_t channels,
    bool align_corners) {
  const int64_t num_planes = channels * nbatch;
  const int64_t in_plane_size = input_width * input_height;
  const int64_t out_plane_size = output_width * output_height;

  const bool identity_resize =
      input_height == output_height && input_width == output_width;
  if (identity_resize) {
    for (int64_t oy = 0; oy < output_height; oy++) {
      for (int64_t ox = 0; ox < output_width; ox++) {
        const int64_t src_offset = oy * output_width + ox;
        const int64_t dst_offset = oy * input_width + ox;
        for (int64_t plane = 0; plane < num_planes; ++plane) {
          idata[plane * in_plane_size + dst_offset] =
              odata[plane * out_plane_size + src_offset];
        }
      }
    }
    return;
  }

  const scalar_t height_scale = area_pixel_compute_scale<scalar_t>(
      input_height, output_height, align_corners);
  const scalar_t width_scale = area_pixel_compute_scale<scalar_t>(
      input_width, output_width, align_corners);

  for (int64_t oy = 0; oy < output_height; oy++) {
    for (int64_t ox = 0; ox < output_width; ox++) {
      // Source position of this output pixel: integer base + fractional part.
      const scalar_t real_x = area_pixel_compute_source_index(
          width_scale, ox, align_corners, /*cubic=*/true);
      const int64_t base_x = floorf(real_x);
      const scalar_t frac_x = real_x - base_x;

      const scalar_t real_y = area_pixel_compute_source_index(
          height_scale, oy, align_corners, /*cubic=*/true);
      const int64_t base_y = floorf(real_y);
      const scalar_t frac_y = real_y - base_y;

      // Per-axis cubic weights for the 4 taps at offsets -1, 0, +1, +2.
      scalar_t wx[4];
      scalar_t wy[4];
      get_cubic_upsample_coefficients<scalar_t>(wx, frac_x);
      get_cubic_upsample_coefficients<scalar_t>(wy, frac_y);

      const int64_t grad_offset = oy * output_width + ox;
      for (int64_t plane = 0; plane < num_planes; plane++) {
        const scalar_t grad = odata[plane * out_plane_size + grad_offset];
        scalar_t* const grad_in_plane = idata + plane * in_plane_size;
        // Keep i (x tap) outer and j (y tap) inner: clamped taps can hit the
        // same input cell, so the accumulation order is observable in floats.
        for (int64_t i = 0; i < 4; i++) {
          for (int64_t j = 0; j < 4; j++) {
            upsample_increment_value_bounded<scalar_t>(
                grad_in_plane,
                input_width,
                input_height,
                base_x - 1 + i,
                base_y - 1 + j,
                grad * wy[j] * wx[i]);
          }
        }
      }
    }
  }
}
// Shared implementation of the forward op.
//
// Reads N, C, H, W from `input_` and validates them against `output_size`
// ({out_h, out_w}) via upsample_2d_shape_check. It then resizes and zeroes
// `output` to {N, C, out_h, out_w} and runs the bicubic kernel on a
// contiguous copy of the input. Dispatches over floating types plus Half.
static void upsample_bicubic2d_out_cpu_template(
    Tensor& output,
    const Tensor& input_,
    IntArrayRef output_size,
    bool align_corners) {
  TORCH_CHECK(
      output_size.size() == 2,
      "It is expected output_size equals to 2, but got size ",
      output_size.size());

  const int64_t output_height = output_size[0];
  const int64_t output_width = output_size[1];

  const int64_t nbatch = input_.size(0);
  const int64_t channels = input_.size(1);
  const int64_t input_height = input_.size(2);
  const int64_t input_width = input_.size(3);

  upsample_2d_shape_check(
      input_, Tensor(),
      nbatch, channels,
      input_height, input_width,
      output_height, output_width);

  // The kernel indexes raw memory, so it needs a dense input buffer.
  Tensor input = input_.contiguous();
  output.resize_({nbatch, channels, output_height, output_width});
  output.zero_();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "upsample_bicubic2d", [&] {
    scalar_t* const src = input.data<scalar_t>();
    scalar_t* const dst = output.data<scalar_t>();
    upsample_bicubic2d_out_frame<scalar_t>(
        dst, src,
        input_height, input_width,
        output_height, output_width,
        nbatch, channels,
        align_corners);
  });
}
// Shared implementation of the backward op.
//
// `output_size` is {out_h, out_w} and `input_size` is the forward input's
// {N, C, H, W}. After validating both against `grad_output_`, `grad_input`
// is resized to {N, C, H, W}, zeroed, and filled by the backward kernel from
// a contiguous copy of the incoming gradient.
static void upsample_bicubic2d_backward_out_cpu_template(
    Tensor& grad_input,
    const Tensor& grad_output_,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners) {
  TORCH_CHECK(
      output_size.size() == 2,
      "It is expected output_size equals to 2, but got size ",
      output_size.size());
  TORCH_CHECK(
      input_size.size() == 4,
      "It is expected input_size equals to 4, but got size ",
      input_size.size());

  const int64_t output_height = output_size[0];
  const int64_t output_width = output_size[1];

  const int64_t nbatch = input_size[0];
  const int64_t channels = input_size[1];
  const int64_t input_height = input_size[2];
  const int64_t input_width = input_size[3];

  upsample_2d_shape_check(
      Tensor(), grad_output_,
      nbatch, channels,
      input_height, input_width,
      output_height, output_width);

  Tensor grad_output = grad_output_.contiguous();
  grad_input.resize_({nbatch, channels, input_height, input_width});
  // The kernel's general path increments grad_input in place, so it must
  // start from zero.
  grad_input.zero_();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_output.scalar_type(), "upsample_bicubic2d_backward", [&] {
        scalar_t* const grad_in_data = grad_input.data<scalar_t>();
        scalar_t* const grad_out_data = grad_output.data<scalar_t>();
        upsample_bicubic2d_backward_out_frame<scalar_t>(
            grad_out_data, grad_in_data,
            input_height, input_width,
            output_height, output_width,
            nbatch, channels,
            align_corners);
      });
}
} // namespace
// `out=` variant of the forward op: fills the caller-provided `output`
// (resizing it as needed) and returns that same tensor.
Tensor& upsample_bicubic2d_out_cpu(
    Tensor& output,
    const Tensor& input,
    IntArrayRef output_size,
    bool align_corners) {
  upsample_bicubic2d_out_cpu_template(output, input, output_size, align_corners);
  return output;
}
// Functional forward op: allocates an empty tensor with the input's options;
// the shared template then sizes and fills it.
Tensor upsample_bicubic2d_cpu(
    const Tensor& input,
    IntArrayRef output_size,
    bool align_corners) {
  Tensor result = at::empty({0}, input.options());
  upsample_bicubic2d_out_cpu_template(result, input, output_size, align_corners);
  return result;
}
// `out=` variant of the backward op: writes the input gradient into the
// caller-provided `grad_input` and returns that same tensor.
Tensor& upsample_bicubic2d_backward_out_cpu(
    Tensor& grad_input,
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners) {
  upsample_bicubic2d_backward_out_cpu_template(
      grad_input, grad_output, output_size, input_size, align_corners);
  return grad_input;
}
// Functional backward op: returns the gradient w.r.t. an input of shape
// `input_size`.
//
// The buffer is allocated uninitialised on purpose. The shared template
// unconditionally resize_()s and zero_()s grad_input before running the
// kernel, so allocating with at::zeros here filled the memory twice.
Tensor upsample_bicubic2d_backward_cpu(
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners) {
  auto grad_input = at::empty(input_size, grad_output.options());
  upsample_bicubic2d_backward_out_cpu_template(
      grad_input, grad_output, output_size, input_size, align_corners);
  return grad_input;
}
} // namespace native
} // namespace at