Skip to content

Commit

Permalink
optimize code
Browse files Browse the repository at this point in the history
  • Loading branch information
jiesutd committed Feb 19, 2019
1 parent ac3497c commit ab82b68
Show file tree
Hide file tree
Showing 10 changed files with 67 additions and 25 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ __pycache__
demo.clf.*
sent.*
*.out
*.log
13 changes: 7 additions & 6 deletions demo.clf.config
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ sentence_classification=True

### I/O ###
train_dir=../data/Sentclf/SST1/stsa.fine.train.clf
dev_dir=../data/Sentclf/SST1/stsa.fine.train.clf
dev_dir=../data/Sentclf/SST1/stsa.fine.dev.clf
test_dir=../data/Sentclf/SST1/stsa.fine.test.clf
model_dir=sample_data/clf
word_emb_dir=../data/glove.6B.100d.txt
word_emb_dir=../data/glove.840B.300d.txt


#raw_dir=
#decode_dir=
Expand All @@ -33,19 +34,19 @@ char_seq_feature=CNN

###TrainingSetting###
status=train
optimizer=Adadelta
iteration=10
optimizer=SGD
iteration=50
batch_size=10
ave_batch_loss=False

###Hyperparameters###
cnn_layer=4
char_hidden_dim=50
hidden_dim=400
dropout=0.5
dropout=0
lstm_layer=1
bilstm=True
learning_rate=1
learning_rate=0.2
lr_decay=0.05
momentum=0
l2=1e-8
Expand Down
42 changes: 36 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: Jie
# @Date: 2017-06-15 14:11:08
# @Last Modified by: Jie Yang, Contact: [email protected]
# @Last Modified time: 2019-01-14 16:09:16
# @Last Modified time: 2019-02-13 12:41:44

from __future__ import print_function
import time
Expand Down Expand Up @@ -51,6 +51,8 @@ def predict_check(pred_variable, gold_variable, mask_variable, sentence_classifi
mask = mask_variable.cpu().data.numpy()
overlaped = (pred == gold)
if sentence_classification:
# print(overlaped)
# print(overlaped*pred)
right_token = np.sum(overlaped)
total_token = overlaped.shape[0] ## =batch_size
else:
Expand Down Expand Up @@ -359,7 +361,7 @@ def train(data):
model = SentClassifier(data)
else:
model = SeqLabel(data)
# loss_function = nn.NLLLoss()

if data.optimizer.lower() == "sgd":
optimizer = optim.SGD(model.parameters(), lr=data.HP_lr, momentum=data.HP_momentum,weight_decay=data.HP_l2)
elif data.optimizer.lower() == "adagrad":
Expand Down Expand Up @@ -407,7 +409,7 @@ def train(data):
continue
batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask = batchify_with_label(instance, data.HP_gpu, True, data.sentence_classification)
instance_count += 1
loss, tag_seq = model.neg_log_likelihood_loss(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, batch_label, mask)
loss, tag_seq = model.calculate_loss(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, batch_label, mask)
right, whole = predict_check(tag_seq, batch_label, mask, data.sentence_classification)
right_token += right
whole_token += whole
Expand Down Expand Up @@ -504,13 +506,41 @@ def load_model_decode(data, name):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Tuning with NCRF++')
# parser.add_argument('--status', choices=['train', 'decode'], help='update algorithm', default='train')
parser.add_argument('--config', help='Configuration File' )
parser.add_argument('--config', help='Configuration File', default='None')
parser.add_argument('--wordemb', help='Embedding for words', default='None')
parser.add_argument('--charemb', help='Embedding for chars', default='None')
parser.add_argument('--status', choices=['train', 'decode'], help='update algorithm', default='train')
parser.add_argument('--savemodel', default="data/model/saved_model.lstmcrf.")
parser.add_argument('--savedset', help='Dir of saved data setting')
parser.add_argument('--train', default="data/conll03/train.bmes")
parser.add_argument('--dev', default="data/conll03/dev.bmes" )
parser.add_argument('--test', default="data/conll03/test.bmes")
parser.add_argument('--seg', default="True")
parser.add_argument('--raw')
parser.add_argument('--loadmodel')
parser.add_argument('--output')

args = parser.parse_args()
data = Data()
data.HP_gpu = torch.cuda.is_available()
data.read_config(args.config)
data.show_data_summary()
if args.config == 'None':
data.train_dir = args.train
data.dev_dir = args.dev
data.test_dir = args.test
data.model_dir = args.savemodel
data.dset_dir = args.savedset
print("Save dset directory:",data.dset_dir)
save_model_dir = args.savemodel
data.word_emb_dir = args.wordemb
data.char_emb_dir = args.charemb
if args.seg.lower() == 'true':
data.seg = True
else:
data.seg = False
print("Seed num:",seed_num)
else:
data.read_config(args.config)
# data.show_data_summary()
status = data.status.lower()
print("Seed num:",seed_num)

Expand Down
4 changes: 2 additions & 2 deletions main_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: Jie
# @Date: 2017-06-15 14:11:08
# @Last Modified by: Jie Yang, Contact: [email protected]
# @Last Modified time: 2019-01-18 21:02:01
# @Last Modified time: 2019-02-13 10:58:43

from __future__ import print_function
import time
Expand Down Expand Up @@ -434,7 +434,7 @@ def load_model_decode(data, name):
data.test_dir = args.test
data.model_dir = args.savemodel
data.dset_dir = args.savedset
print("aaa",data.dset_dir)
print("dset directory:",data.dset_dir)
status = args.status.lower()
save_model_dir = args.savemodel
data.HP_gpu = torch.cuda.is_available()
Expand Down
7 changes: 5 additions & 2 deletions model/sentclassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: Jie Yang
# @Date: 2019-01-01 21:11:50
# @Last Modified by: Jie Yang, Contact: [email protected]
# @Last Modified time: 2019-01-14 14:56:28
# @Last Modified time: 2019-02-13 12:30:56

from __future__ import print_function
from __future__ import absolute_import
Expand All @@ -27,7 +27,7 @@ def __init__(self, data):



def neg_log_likelihood_loss(self, word_inputs, feature_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover, batch_label, mask):
def calculate_loss(self, word_inputs, feature_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover, batch_label, mask):
outs = self.word_hidden.sentence_representation(word_inputs,feature_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover)
batch_size = word_inputs.size(0)
# loss_function = nn.CrossEntropyLoss(ignore_index=0, reduction='sum')
Expand All @@ -40,6 +40,7 @@ def neg_log_likelihood_loss(self, word_inputs, feature_inputs, word_seq_lengths,
# exit(0)
total_loss = F.cross_entropy(outs, batch_label.view(batch_size))
# total_loss = loss_function(score, batch_label.view(batch_size))

_, tag_seq = torch.max(outs, 1)
if self.average_batch:
total_loss = total_loss / batch_size
Expand All @@ -51,6 +52,8 @@ def forward(self, word_inputs, feature_inputs, word_seq_lengths, char_inputs, ch
batch_size = word_inputs.size(0)
outs = outs.view(batch_size, -1)
_, tag_seq = torch.max(outs, 1)
# if a == 0:
# print(tag_seq)
return tag_seq


4 changes: 2 additions & 2 deletions model/seqlabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: Jie Yang
# @Date: 2017-10-17 16:47:32
# @Last Modified by: Jie Yang, Contact: [email protected]
# @Last Modified time: 2019-01-01 21:10:00
# @Last Modified time: 2019-02-13 11:49:38

from __future__ import print_function
from __future__ import absolute_import
Expand Down Expand Up @@ -33,7 +33,7 @@ def __init__(self, data):
self.crf = CRF(label_size, self.gpu)


def neg_log_likelihood_loss(self, word_inputs, feature_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover, batch_label, mask):
def calculate_loss(self, word_inputs, feature_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover, batch_label, mask):
outs = self.word_hidden(word_inputs,feature_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover)
batch_size = word_inputs.size(0)
seq_len = word_inputs.size(1)
Expand Down
7 changes: 6 additions & 1 deletion model/wordrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: Jie Yang
# @Date: 2017-10-17 16:47:32
# @Last Modified by: Jie Yang, Contact: [email protected]
# @Last Modified time: 2019-01-10 16:41:16
# @Last Modified time: 2019-02-01 15:52:01
from __future__ import print_function
from __future__ import absolute_import
import torch
Expand Down Expand Up @@ -87,7 +87,9 @@ def forward(self, word_inputs,feature_inputs, word_seq_lengths, char_inputs, cha
"""
batch_size = word_inputs.size(0)
sent_len = word_inputs.size(1)

word_embs = self.word_embedding(word_inputs)

word_list = [word_embs]
if not self.sentence_classification:
for idx in range(self.feature_num):
Expand All @@ -109,5 +111,8 @@ def forward(self, word_inputs,feature_inputs, word_seq_lengths, char_inputs, cha
## concat word and char together
word_list.append(char_features_extra)
word_embs = torch.cat(word_list, 2)
# if a == 0:
# print("inputs", word_inputs)
# print("embeddings:", word_embs)
word_represent = self.drop(word_embs)
return word_represent
2 changes: 1 addition & 1 deletion model/wordsequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: Jie Yang
# @Date: 2017-10-17 16:47:32
# @Last Modified by: Jie Yang, Contact: [email protected]
# @Last Modified time: 2019-01-11 13:55:41
# @Last Modified time: 2019-02-01 15:59:26
from __future__ import print_function
from __future__ import absolute_import
import torch
Expand Down
8 changes: 5 additions & 3 deletions utils/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: Jie
# @Date: 2017-06-15 14:23:06
# @Last Modified by: Jie Yang, Contact: [email protected]
# @Last Modified time: 2019-01-14 11:08:45
# @Last Modified time: 2019-02-14 12:23:52
from __future__ import print_function
from __future__ import absolute_import
import sys
Expand Down Expand Up @@ -204,8 +204,10 @@ def load_pretrain_emb(embedding_path):
tokens = line.split()
if embedd_dim < 0:
embedd_dim = len(tokens) - 1
else:
assert (embedd_dim + 1 == len(tokens))
elif embedd_dim + 1 != len(tokens):
## skip malformed embedding lines whose token count does not equal embedd_dim + 1
continue
# assert (embedd_dim + 1 == len(tokens))
embedd = np.empty([1, embedd_dim])
embedd[:] = tokens[1:]
if sys.version_info[0] < 3:
Expand Down
4 changes: 2 additions & 2 deletions utils/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: Jie
# @Date: 2017-02-16 09:53:19
# @Last Modified by: Jie Yang, Contact: [email protected]
# @Last Modified time: 2019-01-18 21:04:10
# @Last Modified time: 2019-02-17 22:46:59

# from operator import add
#
Expand All @@ -27,7 +27,7 @@ def get_ner_fmeasure(golden_lists, predict_lists, label_type="BMES"):
if golden_list[idy] == predict_list[idy]:
right_tag += 1
all_tag += len(golden_list)
if label_type == "BMES" or "BIOES":
if label_type == "BMES" or label_type == "BIOES":
gold_matrix = get_ner_BMES(golden_list)
pred_matrix = get_ner_BMES(predict_list)
else:
Expand Down

0 comments on commit ab82b68

Please sign in to comment.