-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathneuralnetwork.py
87 lines (75 loc) · 2.84 KB
/
neuralnetwork.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
# Width of the feature vector ConvNetwork emits (consumed by the head networks below).
convnetworkoutputsize = 256
# Flattened size of the conv stack's output before the fully connected head.
# 8192 = 128 channels * 8 * 8, which matches e.g. an 84x84 RGB input — TODO confirm resolution.
convoutputsize = 8192


class ConvNetwork(nn.Module):
    """Convolutional encoder: maps a batch of RGB images to 256-d feature vectors."""

    def __init__(self):
        super().__init__()
        # Three conv stages; each is followed by ReLU + 2x2 max-pooling in forward().
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)
        self.pool = nn.MaxPool2d(2, 2)
        # Fully connected head compressing the flattened conv features to 256 dims.
        self.fc1 = nn.Linear(convoutputsize, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, convnetworkoutputsize)

    def forward(self, state):
        """Encode a (batch, 3, H, W) image tensor into (batch, 256) features."""
        features = state
        for conv in (self.conv1, self.conv2, self.conv3):
            features = self.pool(F.relu(conv(features)))
        # Collapse everything but the batch dimension before the linear head.
        features = torch.flatten(features, start_dim=1)
        features = F.relu(self.fc1(features))
        features = F.relu(self.fc2(features))
        return self.fc3(features)
class StateValueNetwork(nn.Module):
    """MLP head estimating the state value V(s) from encoded state features."""

    def __init__(self):
        super().__init__()
        # 256 -> 128 -> 64 -> 1 funnel down to a scalar value estimate.
        self.fc1 = nn.Linear(convnetworkoutputsize, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, state):
        """Map (batch, 256) encoder features to a (batch, 1) value estimate."""
        hidden = F.relu(self.fc1(state))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
class ActionValueNetwork(nn.Module):
    """MLP head estimating Q(s, a) from encoded state features and an 11-d action."""

    def __init__(self):
        super().__init__()
        # Input is the state embedding concatenated with the 11-dimensional action.
        self.fc1 = nn.Linear(convnetworkoutputsize + 11, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        # uniform init layer 3

    def forward(self, state, action):
        """Concatenate state features and action along dim 1; return (batch, 1) Q-values."""
        joint = torch.cat((state, action), dim=1)
        hidden = F.relu(self.fc1(joint))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
class PolicyNetwork(nn.Module):
    """Gaussian policy head: maps encoded state features to an 11-d action distribution."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(convnetworkoutputsize, 128)
        self.fc2 = nn.Linear(128, 64)
        # Per-dimension Gaussian parameter heads (mean and log-scale), 11 action dims.
        self.mean_fc = nn.Linear(64, 11)
        self.log_variance_fc = nn.Linear(64, 11)

    def forward(self, state):
        """Return (mean, log_variance) of the action Gaussian for (batch, 256) features.

        NOTE(review): `log_variance` is exponentiated directly into Normal's scale
        in sample(), so it is effectively a log-STD despite the name; the [-20, 2]
        clamp is the conventional log-std range. Confirm the intended
        parameterization before relying on the name.
        """
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        mean = self.mean_fc(x)
        log_variance = self.log_variance_fc(x)
        # Clamp for numerical stability before exponentiation.
        log_variance = torch.clamp(log_variance, -20, 2)
        return mean, log_variance

    def sample(self, state, epsilon=1e-6):
        """Draw a pre-squash action z and its tanh-corrected log-probability.

        Args:
            state: (batch, 256) encoder features.
            epsilon: small constant keeping the log in the tanh correction finite.

        Returns:
            (mean, variance, z, log_pi); log_pi includes the change-of-variables
            correction for a downstream tanh squashing of z, so callers are
            expected to apply tanh(z) to obtain the bounded action.
        """
        mean, log_variance = self.forward(state)
        variance = log_variance.exp()
        gaussian = Normal(mean, variance)
        # FIX: use rsample() (reparameterization trick) instead of sample() so
        # gradients flow back through the sampling step; sample() detaches the
        # draw and silently breaks actor-loss backprop in SAC-style training.
        z = gaussian.rsample()
        # log N(z | mean, scale) minus log|d tanh(z)/dz|, summed over action dims.
        log_pi = (gaussian.log_prob(z)
                  - torch.log(1 - torch.tanh(z).pow(2) + epsilon)).sum(1, keepdim=True)
        return mean, variance, z, log_pi