diff --git a/main.py b/main.py index 9533ec5..82a710a 100644 --- a/main.py +++ b/main.py @@ -225,7 +225,7 @@ def batchify_sequence_labeling_with_label(input_batch_list, gpu, if_train=True): feature_seq_tensors = [] for idx in range(feature_num): feature_seq_tensors.append(torch.zeros((batch_size, max_seq_len),requires_grad = if_train).long()) - mask = torch.zeros((batch_size, max_seq_len), requires_grad = if_train).byte() + mask = torch.zeros((batch_size, max_seq_len), requires_grad = if_train).bool() for idx, (seq, label, seqlen) in enumerate(zip(words, labels, word_seq_lengths)): seqlen = seqlen.item() word_seq_tensor[idx, :seqlen] = torch.LongTensor(seq) @@ -304,7 +304,7 @@ def batchify_sentence_classification_with_label(input_batch_list, gpu, if_train= feature_seq_tensors = [] for idx in range(feature_num): feature_seq_tensors.append(torch.zeros((batch_size, max_seq_len),requires_grad = if_train).long()) - mask = torch.zeros((batch_size, max_seq_len), requires_grad = if_train).byte() + mask = torch.zeros((batch_size, max_seq_len), requires_grad = if_train).bool() label_seq_tensor = torch.LongTensor(labels) # exit(0) for idx, (seq, seqlen) in enumerate(zip(words, word_seq_lengths)): diff --git a/main_parse.py b/main_parse.py index 47e21f1..d1b0ffa 100644 --- a/main_parse.py +++ b/main_parse.py @@ -233,7 +233,7 @@ def batchify_with_label(input_batch_list, gpu, volatile_flag=False): feature_seq_tensors = [] for idx in range(feature_num): feature_seq_tensors.append(autograd.Variable(torch.zeros((batch_size, max_seq_len)),volatile = volatile_flag).long()) - mask = autograd.Variable(torch.zeros((batch_size, max_seq_len)),volatile = volatile_flag).byte() + mask = autograd.Variable(torch.zeros((batch_size, max_seq_len)),volatile = volatile_flag).bool() for idx, (seq, label, seqlen) in enumerate(zip(words, labels, word_seq_lengths)): word_seq_tensor[idx, :seqlen] = torch.LongTensor(seq) label_seq_tensor[idx, :seqlen] = torch.LongTensor(label) diff --git a/model/crf.py b/model/crf.py index be69bde..863d4c4 100644 --- a/model/crf.py +++ b/model/crf.py @@ -133,7 +133,7 @@ def _viterbi_decode(self, feats, mask): partition_history = list() ## reverse mask (bug for mask = 1- mask, use this as alternative choice) # mask = 1 + (-1)*mask - mask = (1 - mask.long()).byte() + mask = (1 - mask.long()).bool() _, inivalues = next(seq_iter) # bat_size * from_target_size * to_target_size # only need start from start_tag partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size) # bat_size * to_target_size @@ -297,7 +297,7 @@ def _viterbi_decode_nbest(self, feats, mask, nbest): partition_history = list() ## reverse mask (bug for mask = 1- mask, use this as alternative choice) # mask = 1 + (-1)*mask - mask = (1 - mask.long()).byte() + mask = (1 - mask.long()).bool() _, inivalues = next(seq_iter) # bat_size * from_target_size * to_target_size # only need start from start_tag partition = inivalues[:, START_TAG, :].clone() # bat_size * to_target_size