-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
29 lines (21 loc) · 944 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
from torch import optim
from load_data import get_dataset
from model import ConvNet
from train import *
from constants import device
if __name__ == '__main__':
    # Build the network and move it to the configured device before the
    # optimizer is created from its parameters.
    net = ConvNet()
    net.to(device)
    print(net)

    # Random 85% / 15% train/validation split of the full dataset.
    full_data = get_dataset()
    train_part, val_part = random_split(full_data, [0.85, 0.15])
    print(f'Training samples:{len(train_part)}\nValidation samples:{len(val_part)}\n')

    # Both loaders serve mini-batches of 8 and are created with shuffling on.
    # NOTE(review): shuffling the validation loader is usually unnecessary;
    # confirm whether train_interactive relies on it before changing.
    batches_train = DataLoader(train_part, batch_size=8, shuffle=True)
    batches_val = DataLoader(val_part, batch_size=8, shuffle=True)

    # SGD with momentum and a cross-entropy objective.
    sgd = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    # The trailing 20 and 0.95 are passed positionally; their meaning is
    # defined by train_interactive (star-imported from train) -- TODO confirm
    # and switch to keyword arguments once the parameter names are known.
    train_interactive(sgd, net, criterion, batches_train, batches_val, 20, 0.95)

    # Alternative training entry point, kept for reference:
    # train_epochs(18, optim.SGD(net.parameters(), lr=1e-3, momentum=0.9), net, nn.CrossEntropyLoss(), batches_train, batches_val)