modulation.py
'''
Code from : https://github.com/lucidrains/siren-pytorch/blob/master/siren_pytorch/siren_pytorch.py
'''
import math
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

# helpers

def exists(val):
    return val is not None

def cast_tuple(val, repeat = 1):
    return val if isinstance(val, tuple) else ((val,) * repeat)

# sin activation

class Sine(nn.Module):
    def __init__(self, w0 = 1.):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        return torch.sin(self.w0 * x)

# siren layer

class Siren(nn.Module):
    def __init__(self, dim_in, dim_out, w0 = 1., c = 6., is_first = False, use_bias = True, activation = None):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first

        weight = torch.zeros(dim_out, dim_in)
        bias = torch.zeros(dim_out) if use_bias else None
        self.init_(weight, bias, c = c, w0 = w0)

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if use_bias else None
        self.activation = Sine(w0) if activation is None else activation

    def init_(self, weight, bias, c, w0):
        dim = self.dim_in

        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
        weight.uniform_(-w_std, w_std)

        if exists(bias):
            bias.uniform_(-w_std, w_std)

    def forward(self, x):
        out = F.linear(x, self.weight, self.bias)
        out = self.activation(out)
        return out
# siren network

class SirenNet(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out, num_layers, w0 = 1., w0_initial = 30., use_bias = True, final_activation = None):
        super().__init__()
        self.num_layers = num_layers
        self.dim_hidden = dim_hidden

        self.layers = nn.ModuleList([])
        for ind in range(num_layers):
            is_first = ind == 0
            layer_w0 = w0_initial if is_first else w0
            layer_dim_in = dim_in if is_first else dim_hidden

            self.layers.append(Siren(
                dim_in = layer_dim_in,
                dim_out = dim_hidden,
                w0 = layer_w0,
                use_bias = use_bias,
                is_first = is_first
            ))

        final_activation = nn.Identity() if not exists(final_activation) else final_activation
        self.last_layer = Siren(dim_in = dim_hidden, dim_out = dim_out, w0 = w0, use_bias = use_bias, activation = final_activation)

    def forward(self, x, mods = None):
        mods = cast_tuple(mods, self.num_layers)

        for layer, mod in zip(self.layers, mods):
            x = layer(x)

            if exists(mod):
                x *= mod #rearrange(mod, 'd -> () d')

        return self.last_layer(x)
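
# Usage sketch (not part of the upstream file; the dimensions and the helper name
# below are illustrative assumptions): a plain SirenNet maps coordinates directly
# to values, e.g. 2-D pixel coordinates to RGB, with no modulation applied.
def _demo_siren_net():
    net = SirenNet(dim_in = 2, dim_hidden = 256, dim_out = 3, num_layers = 5)
    coords = torch.rand(1024, 2) * 2 - 1   # 2-D coordinates in [-1, 1]
    rgb = net(coords)                      # -> (1024, 3)
    return rgb
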
# modulatory feed forward

class Modulator(nn.Module):
    def __init__(self, dim_in, dim_hidden, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([])

        for ind in range(num_layers):
            is_first = ind == 0
            dim = dim_in if is_first else (dim_hidden + dim_in)

            self.layers.append(nn.Sequential(
                nn.Linear(dim, dim_hidden),
                nn.LeakyReLU()
            ))

        self.weight_init = init_weights_normal
        self.layers.apply(self.weight_init)

    def forward(self, z):
        x = z
        hiddens = []

        for layer in self.layers:
            x = layer(x)
            hiddens.append(x)
            x = torch.cat((x, z), dim=1)

        return tuple(hiddens)

# wrapper

class SirenWrapper(nn.Module):
    def __init__(self, net, latent_dim = None):
        super().__init__()
        self.net = net

        self.modulator = None
        if exists(latent_dim):
            self.modulator = Modulator(
                dim_in = latent_dim,
                dim_hidden = net.dim_hidden,
                num_layers = net.num_layers
            )

    def forward(self, coords, latent = None):
        modulate = exists(self.modulator)
        assert not (modulate ^ exists(latent)), 'a latent vector must only be supplied if `latent_dim` was passed in on instantiation'

        mods = self.modulator(latent) if modulate else None
        out = self.net(coords, mods)
        return out
########################
# Initialization methods

def init_weights_normal(m):
    if type(m) == nn.Linear:
        if hasattr(m, 'weight'):
            nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')

def sine_init(m):
    with torch.no_grad():
        if hasattr(m, 'weight'):
            num_input = m.weight.size(-1)
            m.weight.uniform_(-np.sqrt(6 / num_input) / 30, np.sqrt(6 / num_input) / 30)

def first_layer_sine_init(m):
    with torch.no_grad():
        if hasattr(m, 'weight'):
            num_input = m.weight.size(-1)
            m.weight.uniform_(-1 / num_input, 1 / num_input)
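
# Minimal end-to-end sketch (an assumption, not part of the upstream file): wrap a
# SirenNet with a latent code so the Modulator produces one multiplicative
# modulation per hidden layer. The latent keeps a leading batch dimension of 1 and
# broadcasts over the coordinate batch; all dimensions below are illustrative.
if __name__ == "__main__":
    net = SirenNet(dim_in = 2, dim_hidden = 128, dim_out = 3, num_layers = 4)
    wrapper = SirenWrapper(net, latent_dim = 64)

    coords = torch.rand(1024, 2) * 2 - 1   # 2-D coordinates in [-1, 1]
    latent = torch.randn(1, 64)            # one latent code, shared across all coordinates
    out = wrapper(coords, latent)          # -> (1024, 3)
    print(out.shape)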