Showing 10 changed files with 67 additions and 25 deletions.
@@ -6,3 +6,4 @@ __pycache__
 demo.clf.*
 sent.*
 *.out
+*.log
@@ -2,7 +2,7 @@
 # @Author: Jie
 # @Date: 2017-06-15 14:11:08
 # @Last Modified by: Jie Yang, Contact: [email protected]
-# @Last Modified time: 2019-01-14 16:09:16
+# @Last Modified time: 2019-02-13 12:41:44

 from __future__ import print_function
 import time
@@ -51,6 +51,8 @@ def predict_check(pred_variable, gold_variable, mask_variable, sentence_classifi
     mask = mask_variable.cpu().data.numpy()
     overlaped = (pred == gold)
     if sentence_classification:
+        # print(overlaped)
+        # print(overlaped*pred)
         right_token = np.sum(overlaped)
         total_token = overlaped.shape[0] ## =batch_size
     else:
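For context, the two commented-out prints sit in predict_check, which counts agreement between predicted and gold labels: in sentence-classification mode each batch row carries a single label, so the denominator is simply the batch size, while in sequence-labeling mode the padding mask restricts the count to real tokens. A simplified sketch of that logic, not the file's exact function (the else branch is assumed from the surrounding code):

import numpy as np

def accuracy_counts(pred, gold, mask, sentence_classification=False):
    # pred, gold: integer label arrays; mask: 1 for real tokens, 0 for padding.
    overlaped = (pred == gold)
    if sentence_classification:
        right_token = np.sum(overlaped)         # one label per sentence
        total_token = overlaped.shape[0]        # = batch_size
    else:
        right_token = np.sum(overlaped * mask)  # count only unpadded token positions
        total_token = mask.sum()
    return right_token, total_token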
@@ -359,7 +361,7 @@ def train(data):
         model = SentClassifier(data)
     else:
         model = SeqLabel(data)
-    # loss_function = nn.NLLLoss()
+
     if data.optimizer.lower() == "sgd":
         optimizer = optim.SGD(model.parameters(), lr=data.HP_lr, momentum=data.HP_momentum,weight_decay=data.HP_l2)
     elif data.optimizer.lower() == "adagrad":
@@ -407,7 +409,7 @@ def train(data):
                 continue
             batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask = batchify_with_label(instance, data.HP_gpu, True, data.sentence_classification)
             instance_count += 1
-            loss, tag_seq = model.neg_log_likelihood_loss(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, batch_label, mask)
+            loss, tag_seq = model.calculate_loss(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, batch_label, mask)
             right, whole = predict_check(tag_seq, batch_label, mask, data.sentence_classification)
             right_token += right
             whole_token += whole
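With neg_log_likelihood_loss renamed to calculate_loss in both model classes (see the seqlabel and sentclassifier hunks below), the training loop can call the same method whether it is doing sequence labeling or sentence classification. A minimal sketch of that call site, assuming model is a SeqLabel or SentClassifier instance and the batch tensors come from batchify_with_label; the backward/step lines are the usual PyTorch pattern rather than a quote of the file:

# Unified loss call, independent of the task type.
loss, tag_seq = model.calculate_loss(batch_word, batch_features, batch_wordlen,
                                     batch_char, batch_charlen, batch_charrecover,
                                     batch_label, mask)
loss.backward()      # accumulate gradients
optimizer.step()     # update parameters
model.zero_grad()    # clear gradients before the next batch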
@@ -504,13 +506,41 @@ def load_model_decode(data, name):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Tuning with NCRF++')
     # parser.add_argument('--status', choices=['train', 'decode'], help='update algorithm', default='train')
-    parser.add_argument('--config', help='Configuration File' )
+    parser.add_argument('--config', help='Configuration File', default='None')
+    parser.add_argument('--wordemb', help='Embedding for words', default='None')
+    parser.add_argument('--charemb', help='Embedding for chars', default='None')
+    parser.add_argument('--status', choices=['train', 'decode'], help='update algorithm', default='train')
+    parser.add_argument('--savemodel', default="data/model/saved_model.lstmcrf.")
+    parser.add_argument('--savedset', help='Dir of saved data setting')
+    parser.add_argument('--train', default="data/conll03/train.bmes")
+    parser.add_argument('--dev', default="data/conll03/dev.bmes" )
+    parser.add_argument('--test', default="data/conll03/test.bmes")
+    parser.add_argument('--seg', default="True")
+    parser.add_argument('--raw')
+    parser.add_argument('--loadmodel')
+    parser.add_argument('--output')

     args = parser.parse_args()
     data = Data()
     data.HP_gpu = torch.cuda.is_available()
-    data.read_config(args.config)
-    data.show_data_summary()
+    if args.config == 'None':
+        data.train_dir = args.train
+        data.dev_dir = args.dev
+        data.test_dir = args.test
+        data.model_dir = args.savemodel
+        data.dset_dir = args.savedset
+        print("Save dset directory:",data.dset_dir)
+        save_model_dir = args.savemodel
+        data.word_emb_dir = args.wordemb
+        data.char_emb_dir = args.charemb
+        if args.seg.lower() == 'true':
+            data.seg = True
+        else:
+            data.seg = False
+        print("Seed num:",seed_num)
+    else:
+        data.read_config(args.config)
+    # data.show_data_summary()
     status = data.status.lower()
     print("Seed num:",seed_num)
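After this change the script runs in two modes: if --config is left at its 'None' default, the individual command-line flags (train/dev/test paths, embeddings, --seg, ...) configure Data directly; otherwise the configuration file takes precedence and the flags are ignored. A hedged illustration of the branching using explicit argument lists instead of a real command line (the config filename is hypothetical):

import argparse

parser = argparse.ArgumentParser(description='Tuning with NCRF++')
parser.add_argument('--config', help='Configuration File', default='None')
parser.add_argument('--train', default="data/conll03/train.bmes")

# Mode 1: no config file given, so the path flags (or their defaults) are used.
args = parser.parse_args(['--train', 'data/conll03/train.bmes'])
assert args.config == 'None'   # triggers the flag-driven branch above

# Mode 2: a config file is given and takes precedence (filename is hypothetical).
args = parser.parse_args(['--config', 'demo.train.config'])
assert args.config != 'None'   # triggers data.read_config(args.config)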
@@ -2,7 +2,7 @@
 # @Author: Jie
 # @Date: 2017-06-15 14:11:08
 # @Last Modified by: Jie Yang, Contact: [email protected]
-# @Last Modified time: 2019-01-18 21:02:01
+# @Last Modified time: 2019-02-13 10:58:43

 from __future__ import print_function
 import time
@@ -434,7 +434,7 @@ def load_model_decode(data, name):
     data.test_dir = args.test
     data.model_dir = args.savemodel
     data.dset_dir = args.savedset
-    print("aaa",data.dset_dir)
+    print("dset directory:",data.dset_dir)
     status = args.status.lower()
     save_model_dir = args.savemodel
     data.HP_gpu = torch.cuda.is_available()
@@ -2,7 +2,7 @@
 # @Author: Jie Yang
 # @Date: 2019-01-01 21:11:50
 # @Last Modified by: Jie Yang, Contact: [email protected]
-# @Last Modified time: 2019-01-14 14:56:28
+# @Last Modified time: 2019-02-13 12:30:56

 from __future__ import print_function
 from __future__ import absolute_import
@@ -27,7 +27,7 @@ def __init__(self, data):



-    def neg_log_likelihood_loss(self, word_inputs, feature_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover, batch_label, mask):
+    def calculate_loss(self, word_inputs, feature_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover, batch_label, mask):
         outs = self.word_hidden.sentence_representation(word_inputs,feature_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover)
         batch_size = word_inputs.size(0)
         # loss_function = nn.CrossEntropyLoss(ignore_index=0, reduction='sum')
@@ -40,6 +40,7 @@ def neg_log_likelihood_loss(self, word_inputs, feature_inputs, word_seq_lengths,
         # exit(0)
         total_loss = F.cross_entropy(outs, batch_label.view(batch_size))
         # total_loss = loss_function(score, batch_label.view(batch_size))
+
         _, tag_seq = torch.max(outs, 1)
         if self.average_batch:
             total_loss = total_loss / batch_size
@@ -51,6 +52,8 @@ def forward(self, word_inputs, feature_inputs, word_seq_lengths, char_inputs, ch
         batch_size = word_inputs.size(0)
         outs = outs.view(batch_size, -1)
         _, tag_seq = torch.max(outs, 1)
+        # if a == 0:
+        #     print(tag_seq)
         return tag_seq

@@ -2,7 +2,7 @@
 # @Author: Jie Yang
 # @Date: 2017-10-17 16:47:32
 # @Last Modified by: Jie Yang, Contact: [email protected]
-# @Last Modified time: 2019-01-01 21:10:00
+# @Last Modified time: 2019-02-13 11:49:38

 from __future__ import print_function
 from __future__ import absolute_import
@@ -33,7 +33,7 @@ def __init__(self, data):
         self.crf = CRF(label_size, self.gpu)


-    def neg_log_likelihood_loss(self, word_inputs, feature_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover, batch_label, mask):
+    def calculate_loss(self, word_inputs, feature_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover, batch_label, mask):
         outs = self.word_hidden(word_inputs,feature_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover)
         batch_size = word_inputs.size(0)
         seq_len = word_inputs.size(1)
@@ -2,7 +2,7 @@
 # @Author: Jie Yang
 # @Date: 2017-10-17 16:47:32
 # @Last Modified by: Jie Yang, Contact: [email protected]
-# @Last Modified time: 2019-01-10 16:41:16
+# @Last Modified time: 2019-02-01 15:52:01
 from __future__ import print_function
 from __future__ import absolute_import
 import torch
@@ -87,7 +87,9 @@ def forward(self, word_inputs,feature_inputs, word_seq_lengths, char_inputs, cha
         """
         batch_size = word_inputs.size(0)
         sent_len = word_inputs.size(1)
+
         word_embs = self.word_embedding(word_inputs)
+
         word_list = [word_embs]
         if not self.sentence_classification:
             for idx in range(self.feature_num):
@@ -109,5 +111,8 @@ def forward(self, word_inputs,feature_inputs, word_seq_lengths, char_inputs, cha
                 ## concat word and char together
                 word_list.append(char_features_extra)
             word_embs = torch.cat(word_list, 2)
+        # if a == 0:
+        #     print("inputs", word_inputs)
+        #     print("embeddings:", word_embs)
         word_represent = self.drop(word_embs)
         return word_represent
@@ -2,7 +2,7 @@
 # @Author: Jie Yang
 # @Date: 2017-10-17 16:47:32
 # @Last Modified by: Jie Yang, Contact: [email protected]
-# @Last Modified time: 2019-01-11 13:55:41
+# @Last Modified time: 2019-02-01 15:59:26
 from __future__ import print_function
 from __future__ import absolute_import
 import torch
@@ -2,7 +2,7 @@
 # @Author: Jie
 # @Date: 2017-06-15 14:23:06
 # @Last Modified by: Jie Yang, Contact: [email protected]
-# @Last Modified time: 2019-01-14 11:08:45
+# @Last Modified time: 2019-02-14 12:23:52
 from __future__ import print_function
 from __future__ import absolute_import
 import sys
@@ -204,8 +204,10 @@ def load_pretrain_emb(embedding_path):
             tokens = line.split()
             if embedd_dim < 0:
                 embedd_dim = len(tokens) - 1
-            else:
-                assert (embedd_dim + 1 == len(tokens))
+            elif embedd_dim + 1 != len(tokens):
+                ## ignore illegal embedding line
+                continue
+                # assert (embedd_dim + 1 == len(tokens))
             embedd = np.empty([1, embedd_dim])
             embedd[:] = tokens[1:]
             if sys.version_info[0] < 3:
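The hunk above replaces a hard assertion with a skip: an embedding line whose token count does not match the dimension established by the first line is now ignored instead of aborting the load. A simplified standalone sketch of that parsing behaviour (not the repository's full load_pretrain_emb):

import numpy as np

def load_pretrain_emb_sketch(embedding_path):
    embedd_dim = -1
    embedd_dict = {}
    with open(embedding_path, 'r', encoding='utf8') as f:
        for line in f:
            tokens = line.strip().split()
            if not tokens:
                continue
            if embedd_dim < 0:
                embedd_dim = len(tokens) - 1   # first data line fixes the dimension
            elif embedd_dim + 1 != len(tokens):
                continue                       # illegal line: skip instead of asserting
            embedd_dict[tokens[0]] = np.asarray(tokens[1:], dtype=np.float32)
    return embedd_dict, embedd_dim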
@@ -2,7 +2,7 @@
 # @Author: Jie
 # @Date: 2017-02-16 09:53:19
 # @Last Modified by: Jie Yang, Contact: [email protected]
-# @Last Modified time: 2019-01-18 21:04:10
+# @Last Modified time: 2019-02-17 22:46:59

 # from operator import add
 #
@@ -27,7 +27,7 @@ def get_ner_fmeasure(golden_lists, predict_lists, label_type="BMES"):
             if golden_list[idy] == predict_list[idy]:
                 right_tag += 1
         all_tag += len(golden_list)
-        if label_type == "BMES" or "BIOES":
+        if label_type == "BMES" or label_type == "BIOES":
             gold_matrix = get_ner_BMES(golden_list)
             pred_matrix = get_ner_BMES(predict_list)
         else:
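The one-line fix above addresses a classic Python pitfall: the old condition parses as (label_type == "BMES") or "BIOES", and the non-empty string "BIOES" is always truthy, so the BMES/BIOES branch ran for every label scheme, including BIO. A two-line demonstration:

label_type = "BIO"
print(label_type == "BMES" or "BIOES")                # 'BIOES' is truthy, so the old check always passed
print(label_type == "BMES" or label_type == "BIOES")  # False, the corrected comparison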