forked from AliaksandrSiarohin/cuda-gridsample-grad2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcuda_gridsample.py
133 lines (102 loc) · 5.56 KB
/
cuda_gridsample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from torch.utils.cpp_extension import load
import torch
from pkg_resources import parse_version
import os
# Compile and load the custom C++/CUDA extension that provides the
# second-order grid_sample kernels (grad2_2d / grad2_3d used below).
# Both source files are resolved relative to the directory of this file.
gridsample_grad2 = load(
    name='gridsample_grad2',
    sources=[
        os.path.join(os.path.dirname(__file__), src_name)
        for src_name in ('gridsample_cuda.cpp', 'gridsample_cuda.cu')
    ],
    verbose=True,
)
def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=True):
    """Bilinear ``grid_sample`` whose backward pass can be differentiated again.

    Dispatches on the rank of ``input``: 4-D input is handled by
    ``grid_sample_2d`` and 5-D input by ``grid_sample_3d``.

    Args:
        input: Tensor to sample from; must be 4-D or 5-D.
        grid: Sampling grid, forwarded unchanged to the 2-D/3-D helper.
        mode: Interpolation mode. Only ``'bilinear'`` is supported, so it is
            also the default (matching ``torch.nn.functional.grid_sample``).
        padding_mode: ``'zeros'`` or ``'border'``; validated by the helper.
        align_corners: Forwarded to ``torch.nn.functional.grid_sample``.

    Returns:
        The sampled tensor.

    Raises:
        AssertionError: if ``mode`` is not ``'bilinear'`` (when assertions
            are enabled).
        NotImplementedError: if ``input`` is neither 4-D nor 5-D.
    """
    assert mode == 'bilinear', f"only mode='bilinear' is supported, got {mode!r}"
    ndim = input.dim()
    if ndim == 4:
        return grid_sample_2d(input, grid, padding_mode, align_corners)
    if ndim == 5:
        return grid_sample_3d(input, grid, padding_mode, align_corners)
    raise NotImplementedError(f'grid_sample supports 4-D or 5-D input, got {ndim}-D input')
def grid_sample_2d(input, grid, padding_mode='zeros', align_corners=True):
    """Bilinearly sample a 4-D ``input`` at the locations given by ``grid``.

    Only ``'zeros'`` and ``'border'`` padding are accepted (checked with an
    ``assert``). The work is delegated to ``_GridSample2dForward``.
    """
    allowed_padding = ('zeros', 'border')
    assert padding_mode in allowed_padding
    forward_fn = _GridSample2dForward.apply
    return forward_fn(input, grid, padding_mode, align_corners)
def grid_sample_3d(input, grid, padding_mode='zeros', align_corners=True):
    """Trilinearly sample a 5-D ``input`` at the locations given by ``grid``.

    Only ``'zeros'`` and ``'border'`` padding are accepted (checked with an
    ``assert``). The work is delegated to ``_GridSample3dForward``.
    """
    allowed_padding = ('zeros', 'border')
    assert padding_mode in allowed_padding
    forward_fn = _GridSample3dForward.apply
    return forward_fn(input, grid, padding_mode, align_corners)
# True when the installed PyTorch is 1.11 or newer. The backward Functions in
# this module then pass a trailing ``output_mask`` argument to the aten
# ``grid_sampler_*_backward`` ops. The '1.11.0a' bound presumably exists so
# that 1.11 pre-release builds also qualify.
_use_pytorch_1_11_api = parse_version('1.11.0a') <= parse_version(torch.__version__)
class _GridSample2dForward(torch.autograd.Function):
@staticmethod
def forward(ctx, input, grid, padding_mode=0, align_corners=True):
assert input.ndim == 4
assert grid.ndim == 4
assert input.shape[0] == grid.shape[0]
assert grid.shape[3] == 2
output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear',
padding_mode=padding_mode, align_corners=align_corners)
ctx.save_for_backward(input, grid)
ctx.padding_mode = ['zeros', 'border'].index(padding_mode)
ctx.align_corners = align_corners
return output
@staticmethod
def backward(ctx, grad_output):
input, grid = ctx.saved_tensors
grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid, ctx.padding_mode, ctx.align_corners)
return grad_input, grad_grid, None, None
class _GridSample2dBackward(torch.autograd.Function):
    """First-order backward of bilinear 2-D grid_sample, made differentiable.

    ``forward`` computes the ordinary grid_sample gradients with the aten
    ``grid_sampler_2d_backward`` op. ``backward`` computes the second-order
    terms with the custom extension ``gridsample_grad2.grad2_2d``, which
    requires CUDA tensors.
    """

    @staticmethod
    def forward(ctx, grad_output, input, grid, padding_mode=0, align_corners=True):
        """Compute ``(grad_input, grad_grid)`` for bilinear grid_sample.

        Args:
            ctx: autograd context.
            grad_output: gradient w.r.t. the grid_sample output.
            input, grid: the tensors originally passed to grid_sample.
            padding_mode: integer padding enum (0 = 'zeros', 1 = 'border').
            align_corners: same flag as in the forward grid_sample call.
        """
        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        # Newer PyTorch releases return an ``(op, overload_names)`` tuple here,
        # while older ones return the callable directly; accept both forms.
        if isinstance(op, tuple):
            op = op[0]
        # The literal 0 is the interpolation-mode enum for bilinear.
        if _use_pytorch_1_11_api:
            # needs_input_grad indices follow apply()'s argument order:
            # 1 -> input, 2 -> grid.
            output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
            grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners, output_mask)
        else:
            grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners)
        ctx.save_for_backward(grad_output, input, grid)
        ctx.padding_mode = padding_mode
        ctx.align_corners = align_corners
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        """Second-order gradients via the ``gridsample_grad2`` CUDA extension.

        The extension's outputs are treated, in order, as the gradients
        w.r.t. ``grad_output``, ``input`` and ``grid``. The two non-tensor
        arguments get ``None``. All tensors must be on CUDA.
        """
        grad_output, input, grid = ctx.saved_tensors
        assert grad_output.is_cuda and input.is_cuda and grid.is_cuda and grad2_grad_input.is_cuda and grad2_grad_grid.is_cuda
        out = gridsample_grad2.grad2_2d(grad2_grad_input, grad2_grad_grid, grad_output,
                                        input, grid, ctx.padding_mode, ctx.align_corners)
        grad_grad_output = out[0]
        grad_input = out[1]
        grad_grid = out[2]
        return grad_grad_output, grad_input, grad_grid, None, None
class _GridSample3dForward(torch.autograd.Function):
@staticmethod
def forward(ctx, input, grid, padding_mode=0, align_corners=True):
assert input.ndim == 5
assert grid.ndim == 5
assert input.shape[0] == grid.shape[0]
assert grid.shape[4] == 3
output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear',
padding_mode=padding_mode, align_corners=align_corners)
ctx.save_for_backward(input, grid)
ctx.padding_mode = ['zeros', 'border'].index(padding_mode)
ctx.align_corners = align_corners
return output
@staticmethod
def backward(ctx, grad_output):
input, grid = ctx.saved_tensors
grad_input, grad_grid = _GridSample3dBackward.apply(grad_output, input, grid, ctx.padding_mode, ctx.align_corners)
return grad_input, grad_grid, None, None
class _GridSample3dBackward(torch.autograd.Function):
    """First-order backward of trilinear 3-D grid_sample, made differentiable.

    ``forward`` computes the ordinary grid_sample gradients with the aten
    ``grid_sampler_3d_backward`` op. ``backward`` computes the second-order
    terms with the custom extension ``gridsample_grad2.grad2_3d``, which
    requires CUDA tensors.
    """

    @staticmethod
    def forward(ctx, grad_output, input, grid, padding_mode=0, align_corners=True):
        """Compute ``(grad_input, grad_grid)`` for 5-D bilinear grid_sample.

        Args:
            ctx: autograd context.
            grad_output: gradient w.r.t. the grid_sample output.
            input, grid: the tensors originally passed to grid_sample.
            padding_mode: integer padding enum (0 = 'zeros', 1 = 'border').
            align_corners: same flag as in the forward grid_sample call.
        """
        op = torch._C._jit_get_operation('aten::grid_sampler_3d_backward')
        # Newer PyTorch releases return an ``(op, overload_names)`` tuple here,
        # while older ones return the callable directly; accept both forms.
        if isinstance(op, tuple):
            op = op[0]
        # The literal 0 is the interpolation-mode enum for bilinear.
        if _use_pytorch_1_11_api:
            # needs_input_grad indices follow apply()'s argument order:
            # 1 -> input, 2 -> grid.
            output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
            grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners, output_mask)
        else:
            grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners)
        ctx.save_for_backward(grad_output, input, grid)
        ctx.padding_mode = padding_mode
        ctx.align_corners = align_corners
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        """Second-order gradients via the ``gridsample_grad2`` CUDA extension.

        The extension's outputs are treated, in order, as the gradients
        w.r.t. ``grad_output``, ``input`` and ``grid``. The two non-tensor
        arguments get ``None``. All tensors must be on CUDA.
        """
        grad_output, input, grid = ctx.saved_tensors
        assert grad_output.is_cuda and input.is_cuda and grid.is_cuda and grad2_grad_input.is_cuda and grad2_grad_grid.is_cuda
        out = gridsample_grad2.grad2_3d(grad2_grad_input, grad2_grad_grid, grad_output,
                                        input, grid, ctx.padding_mode, ctx.align_corners)
        grad_grad_output = out[0]
        grad_input = out[1]
        grad_grid = out[2]
        return grad_grad_output, grad_input, grad_grid, None, None