-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtask.py
86 lines (66 loc) · 2.45 KB
/
task.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import numpy as np
import random
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from model import HETextCNN
def binary_acc(preds, y):
    """Return the fraction of elements where ``preds`` equals ``y``.

    Args:
        preds: tensor of predicted class indices.
        y: tensor of target class indices, compared element-wise with
            ``preds`` via ``torch.eq``.

    Returns:
        0-d float tensor in [0, 1]. Empty inputs yield NaN (0 / 0).
    """
    correct = torch.eq(preds, y).float()
    # numel() rather than len(): len() counts only the first dimension, so
    # multi-dimensional inputs would otherwise report accuracies above 1.
    return correct.sum() / correct.numel()
# Number of batches (not examples) from the dev iterator that Task.evaluate
# averages over; batches with index >= this value are not evaluated.
test_set_size = 16
class Task:
    """A training/evaluation unit owning a HETextCNN and a slice of train batches.

    The model's embedding layer is initialised from the loader's pretrained
    vectors, and the task keeps the training batches whose index in
    ``dataloader.train_iter`` lies in ``[left, right)``.
    """

    def __init__(self, dataloader, left, right, *, sequence_length=40,
                 num_filters=2, filter_sizes=(2, 3, 5), num_classes=2):
        """Build the model and collect this task's training batches.

        Args:
            dataloader: object exposing ``TEXT.vocab.vectors`` (an embedding
                matrix of shape (vocab_size, embedding_size)), ``train_iter``
                and ``dev_iter``.
            left: index of the first training batch kept (inclusive).
            right: index of the last training batch kept (exclusive).
            sequence_length: passed through to ``HETextCNN``.
            num_filters: number of filters per window size.
            filter_sizes: n-gram window sizes of the convolutions.
            num_classes: number of output classes (2 for binary labels).
        """
        self.loader = dataloader
        vectors = self.loader.TEXT.vocab.vectors
        vocab_size, embedding_size = vectors.shape[0], vectors.shape[1]

        self.train_batches = []
        for batch_idx, batch in enumerate(self.loader.train_iter):
            # Stop once the slice is complete instead of walking the whole
            # iterator (which would never end if the iterator repeats).
            # NOTE(review): if train_iter shuffles with its own RNG, stopping
            # early may change how later passes are shuffled; the batches
            # collected for this slice are the same either way.
            if batch_idx >= right:
                break
            if batch_idx >= left:
                self.train_batches.append(batch)

        self.model = HETextCNN(num_filters, list(filter_sizes), vocab_size,
                               embedding_size, sequence_length, num_classes)
        # Initialise the embedding layer with the pretrained vectors.
        with torch.no_grad():
            self.model.embedding.weight.copy_(vectors)

    def copyfrom(self, model):
        """Copy ``model``'s weights into this task's model (by value).

        Delegates to ``HETextCNN.aggregate`` with ``model`` as the only
        source and weight 1; ``self.model`` keeps its identity.
        """
        self.model.aggregate([model], [1])

    def update(self, model):
        """Replace this task's model with ``model`` (shared by reference)."""
        self.model = model

    def train(self):
        """Run a training-mode forward pass on one random training batch.

        Returns:
            ``(predicted, labels)`` where ``labels`` is ``batch.label - 1``
            (presumably shifting 1-based label ids to 0-based class
            indices; TODO confirm against the label field).

        Raises:
            ValueError: if this task holds no training batches.
        """
        if not self.train_batches:
            raise ValueError(
                "Task has no training batches; check the left/right slice")
        self.model.train()
        batch = random.choice(self.train_batches)
        text, labels = batch.text, batch.label - 1
        predicted = self.model(text)
        return predicted, labels

    def backpropagation(self, optimizer, predicted, labels):
        """Take one optimiser step on the cross-entropy loss of ``predicted``.

        Args:
            optimizer: optimiser presumably built over ``self.model``'s
                parameters.
            predicted: logits from ``train()``, presumably shaped
                (batch, num_classes).
            labels: class indices from ``train()``.

        Returns:
            ``(self.model, acc)``, where ``acc`` is the batch accuracy (0-d
            tensor) of the arg-max predictions.
        """
        self.model.train()
        criterion = nn.CrossEntropyLoss()
        acc = binary_acc(torch.max(predicted, dim=1)[1], labels)
        loss = criterion(predicted, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return self.model, acc

    def evaluate(self, model):
        """Return ``self.model``'s mean accuracy over the first dev batches.

        Only the first ``test_set_size`` batches of ``loader.dev_iter`` are
        evaluated, without gradient tracking; each batch's accuracy is
        weighted equally. ``self.model`` is left in eval mode.

        Args:
            model: unused; ``self.model`` is what gets evaluated.
                NOTE(review): confirm callers do not expect ``model`` itself
                to be evaluated.

        Returns:
            Mean per-batch accuracy as a float, or NaN if no batch was seen.
        """
        self.model.eval()
        batch_accs = []
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.loader.dev_iter):
                # Stop instead of skipping: skipping still walks the whole
                # iterator and never terminates on a repeating one.
                if batch_idx >= test_set_size:
                    break
                text, labels = batch.text, batch.label - 1
                predicted = self.model(text)
                acc = binary_acc(torch.max(predicted, dim=1)[1], labels)
                # .item() yields a Python float, so this also works for
                # tensors that numpy cannot convert directly (e.g. on GPU).
                batch_accs.append(acc.item())
        return float(np.mean(batch_accs)) if batch_accs else float("nan")