-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path01_modules.py
28 lines (25 loc) · 811 Bytes
/
01_modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import torch.nn as nn
class OurModule(nn.Module):
    """Small fully-connected network that ends in a softmax.

    The layers are stored, in execution order, in the ``pipe`` attribute
    (an ``nn.Sequential``):

        Linear(num_inputs -> 5) -> ReLU -> Linear(5 -> 20) -> ReLU
        -> Linear(20 -> num_classes) -> Dropout(dropout_prob) -> Softmax(dim=1)

    Note that dropout is applied to the output of the last linear layer,
    i.e. right before the softmax.

    Args:
        num_inputs: Size of the last dimension of the input fed to the
            first linear layer.
        num_classes: Number of output features of the last linear layer.
        dropout_prob: Probability passed to ``nn.Dropout`` (default 0.3).
    """

    def __init__(self, num_inputs, num_classes, dropout_prob=0.3):
        super().__init__()

        # Widths of the two hidden layers.
        first_hidden, second_hidden = 5, 20

        # The list is built in the same order the layers run, so the
        # Sequential indices (and therefore the state-dict keys) are 0..6.
        stages = [
            nn.Linear(num_inputs, first_hidden),
            nn.ReLU(),
            nn.Linear(first_hidden, second_hidden),
            nn.ReLU(),
            nn.Linear(second_hidden, num_classes),
            nn.Dropout(p=dropout_prob),
            nn.Softmax(dim=1),
        ]
        self.pipe = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through every stage of ``pipe`` and return the result.

        Softmax normalises over dimension 1, so ``x`` needs at least two
        dimensions (presumably ``(batch, num_inputs)`` -- confirm with callers).
        """
        result = self.pipe(x)
        return result
if __name__ == "__main__":
    # Demo: build a 2-input / 3-class network and show its layer layout.
    model = OurModule(num_inputs=2, num_classes=3)
    print(model)

    # Push a single two-feature sample through the network and print
    # what comes out.
    sample = torch.FloatTensor([[2, 3]])
    prediction = model(sample)
    print(prediction)

    # Report whether CUDA is usable; when it is, also print the result
    # after moving it to the CUDA device.
    cuda_ready = torch.cuda.is_available()
    print("Cuda's availability is %s" % cuda_ready)
    if cuda_ready:
        on_gpu = prediction.to('cuda')
        print("Data from cuda: %s" % on_gpu)