bert_lstm.py
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel

# Fix random seeds for reproducibility
torch.manual_seed(2020)
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    torch.cuda.manual_seed(2020)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

class bert_lstm(nn.Module):
    def __init__(self, hidden_dim, output_size, n_layers, bidirectional=True, drop_prob=0.5):
        super(bert_lstm, self).__init__()
        self.output_size = output_size
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.bidirectional = bidirectional

        # BERT ---------------- key point: the BERT model is embedded inside the custom module
        self.bert = BertModel.from_pretrained("./chinese-bert_chinese_wwm_pytorch/data")
        for param in self.bert.parameters():
            param.requires_grad = True  # fine-tune BERT together with the LSTM head

        # LSTM over the 768-dim BERT token embeddings
        self.lstm = nn.LSTM(768, hidden_dim, n_layers, batch_first=True, bidirectional=bidirectional)
        # dropout
        self.dropout = nn.Dropout(drop_prob)
        # linear classifier head
        if bidirectional:
            self.fc = nn.Linear(hidden_dim * 2, output_size)
        else:
            self.fc = nn.Linear(hidden_dim, output_size)

    def forward(self, x, hidden):
        batch_size = x.size(0)

        # Generate BERT token embeddings (last hidden state): [batch_size, seq_len, 768]
        x = self.bert(x)[0]
        lstm_out, (hidden_last, cn_last) = self.lstm(x, hidden)

        # The bidirectional case needs separate handling
        if self.bidirectional:
            # forward direction: last layer, last time step
            hidden_last_L = hidden_last[-2]
            # print(hidden_last_L.shape)  # [32, 384]
            # backward direction: last layer, last time step
            hidden_last_R = hidden_last[-1]
            # print(hidden_last_R.shape)  # [32, 384]
            # concatenate the two directions
            hidden_last_out = torch.cat([hidden_last_L, hidden_last_R], dim=-1)
            # print(hidden_last_out.shape, 'hidden_last_out')  # [32, 768]
        else:
            hidden_last_out = hidden_last[-1]  # [32, 384]

        # dropout
        out = self.dropout(hidden_last_out)
        # linear classifier head
        out = self.fc(out)
        return out

    def init_hidden(self, batch_size):
        # Build zero-initialized (h_0, c_0) states shaped [n_layers * num_directions, batch_size, hidden_dim]
        weight = next(self.parameters()).data
        number = 2 if self.bidirectional else 1
        if USE_CUDA:
            hidden = (weight.new(self.n_layers * number, batch_size, self.hidden_dim).zero_().float().cuda(),
                      weight.new(self.n_layers * number, batch_size, self.hidden_dim).zero_().float().cuda())
        else:
            hidden = (weight.new(self.n_layers * number, batch_size, self.hidden_dim).zero_().float(),
                      weight.new(self.n_layers * number, batch_size, self.hidden_dim).zero_().float())
        return hidden
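

# --- Usage sketch (not part of the original file) ---
# A minimal, hypothetical example of how this module might be driven: tokenize a
# small batch with the same local BERT checkpoint, build the initial LSTM state
# with init_hidden, and run one forward pass. The tokenizer path, example
# sentences, batch size, and hyperparameters below are assumptions, not taken
# from the original code; hidden_dim=384 matches the shape comments above.
if __name__ == "__main__":
    tokenizer = BertTokenizer.from_pretrained("./chinese-bert_chinese_wwm_pytorch/data")
    sentences = ["这部电影很好看", "剧情太差了"]
    encoded = tokenizer(sentences, padding=True, truncation=True,
                        max_length=64, return_tensors="pt")
    input_ids = encoded["input_ids"]  # [batch_size, seq_len]
    # Note: forward() only accepts token ids, so no attention mask is passed.

    model = bert_lstm(hidden_dim=384, output_size=2, n_layers=2,
                      bidirectional=True, drop_prob=0.5)
    if USE_CUDA:
        model = model.cuda()
        input_ids = input_ids.cuda()

    hidden = model.init_hidden(batch_size=input_ids.size(0))
    logits = model(input_ids, hidden)  # [batch_size, output_size]
    print(logits.shape)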