forked from FuNian788/Pytorch-BMN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
132 lines (98 loc) · 4.88 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# coding: utf-8
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from dataset import MyDataset
from loss import bmn_loss, get_mask
from model import BMN_model
from opt import MyConfig
from utils.opt_utils import get_cur_time_stamp
# GPU setting.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # enumerate GPUs in PCI-bus order so device IDs are stable
os.environ["CUDA_VISIBLE_DEVICES"] = "3"        # restrict this process to physical GPU 3

# Environment diagnostics.
print("Pytorch's version is {}.".format(torch.__version__))
print("CUDNN's version is {}.".format(torch.backends.cudnn.version()))
print("CUDA's state is {}.".format(torch.cuda.is_available()))
print("CUDA's version is {}.".format(torch.version.cuda))
# get_device_name(0) raises when no CUDA device is visible, so only query it
# after confirming CUDA is actually available.
if torch.cuda.is_available():
    print("GPU's type is {}.".format(torch.cuda.get_device_name(0)))
# torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
if __name__ == "__main__":
    opt = MyConfig()
    opt.parse()

    # Timestamp string used to namespace this run's log and checkpoints.
    start_time = str(get_cur_time_stamp())

    # Ensure every output directory exists, including the per-run
    # subdirectory (save_path + start_time) that log.txt and the checkpoints
    # are written into — the original code never created it, so the first
    # write during training crashed with FileNotFoundError.
    os.makedirs(opt.checkpoint_path, exist_ok=True)
    os.makedirs(opt.save_path, exist_ok=True)
    run_dir = opt.save_path + start_time
    os.makedirs(run_dir, exist_ok=True)

    # Model, optimizer, LR schedule.
    model = BMN_model(opt)
    model = nn.DataParallel(model).cuda()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                           lr=opt.learning_rate, weight_decay=opt.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=opt.step_size, gamma=opt.step_gamma)

    # Optionally resume from a saved checkpoint.
    if opt.train_from_checkpoint:
        # TODO(review): checkpoint file name '9_param.pth.tar' is hard-coded —
        # consider making it configurable via opt.
        checkpoint = torch.load(opt.checkpoint_path + '9_param.pth.tar')
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # The saved epoch already finished (it is written after the epoch
        # completes), so resume with the NEXT epoch instead of repeating it.
        start_epoch = checkpoint['epoch'] + 1
    else:
        start_epoch = 1

    # NOTE(review): train and valid datasets are built from the same opt with
    # no visible mode/subset flag — confirm MyDataset distinguishes the splits.
    train_dataset = MyDataset(opt)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True,
                                                   num_workers=opt.num_workers, pin_memory=True)
    valid_dataset = MyDataset(opt)
    valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=opt.batch_size, shuffle=True,
                                                   num_workers=opt.num_workers, pin_memory=True)

    # The BM mask depends only on the fixed temporal scale: build it once,
    # not once per training iteration as the original did.
    bm_mask = get_mask(opt.temporal_scale).cuda()

    valid_best_loss = float('inf')
    for epoch in tqdm(range(start_epoch, opt.epochs + 1)):
        # ---- Train ----
        model.train()
        torch.cuda.empty_cache()
        epoch_train_loss = 0
        for train_iter, train_data in tqdm(enumerate(train_dataloader, start=1)):
            optimizer.zero_grad()
            video_feature, gt_iou_map, start_score, end_score = train_data
            video_feature = video_feature.cuda()
            gt_iou_map = gt_iou_map.cuda()
            start_score = start_score.cuda()
            end_score = end_score.cuda()
            bm_confidence_map, start, end = model(video_feature)
            # bmn_loss returns (total_loss, tem_loss, pem_reg_loss, pem_cls_loss);
            # only the total loss is backpropagated.
            train_loss = bmn_loss(bm_confidence_map, start, end, gt_iou_map, start_score, end_score, bm_mask)
            train_loss[0].backward()
            optimizer.step()
            epoch_train_loss += train_loss[0].item()
        scheduler.step()

        # ---- Validate ----
        epoch_valid_loss = 0
        with torch.no_grad():
            model.eval()
            for valid_iter, valid_data in enumerate(valid_dataloader, start=1):
                video_feature, gt_iou_map, start_score, end_score = valid_data
                video_feature = video_feature.cuda()
                gt_iou_map = gt_iou_map.cuda()
                start_score = start_score.cuda()
                end_score = end_score.cuda()
                bm_confidence_map, start, end = model(video_feature)
                valid_loss = bmn_loss(bm_confidence_map, start, end, gt_iou_map, start_score, end_score, bm_mask)
                epoch_valid_loss += valid_loss[0].item()

        # ---- Report ----
        if epoch <= 10 or epoch % 5 == 0:
            message = 'Epoch {}: Training loss {:.3}, Validation loss {:.3}'.format(
                epoch, float(epoch_train_loss / train_iter), float(epoch_valid_loss / valid_iter))
            print(message)
            with open(run_dir + '/log.txt', 'a') as f:
                f.write(message + ' \n')

        # ---- Checkpoint ----
        # Checked every epoch (not only on report epochs) so an improvement
        # on a non-report epoch is never lost. Parameters only, not the
        # whole module, so loading is architecture-version tolerant.
        if epoch_valid_loss < valid_best_loss:
            checkpoint = {'state_dict': model.state_dict(),
                          'optimizer': optimizer.state_dict(),
                          'epoch': epoch}
            torch.save(checkpoint, run_dir + '/' + str(epoch) + '_param.pth.tar')
            valid_best_loss = epoch_valid_loss

    # Save whole model.
    # torch.save(model, opt.save_path)