diff --git a/utils/grounding.py b/utils/grounding.py
index 62ed497..99d23d9 100644
--- a/utils/grounding.py
+++ b/utils/grounding.py
@@ -6,10 +6,8 @@
 import json
 import string
 
-
 __all__ = ['create_matcher_patterns', 'ground']
 
-
 # the lemma of it/them/mine/.. is -PRON-
 
 blacklist = set(["-PRON-", "actually", "likely", "possibly", "want",
@@ -17,7 +15,6 @@
                  "one", "something", "sometimes", "everybody", "somebody", "could", "could_be"
                  ])
 
-
 nltk.download('stopwords', quiet=True)
 nltk_stopwords = nltk.corpus.stopwords.words('english')
 
@@ -37,10 +34,12 @@ def load_cpnet_vocab(cpnet_vocab_path):
 
 
 def create_pattern(nlp, doc, debug=False):
-    pronoun_list = set(["my", "you", "it", "its", "your", "i", "he", "she", "his", "her", "they", "them", "their", "our", "we"])
+    pronoun_list = set(
+        ["my", "you", "it", "its", "your", "i", "he", "she", "his", "her", "they", "them", "their", "our", "we"])
     # Filtering concepts consisting of all stop words and longer than four words.
     if len(doc) >= 5 or doc[0].text in pronoun_list or doc[-1].text in pronoun_list or \
-            all([(token.text in nltk_stopwords or token.lemma_ in nltk_stopwords or token.lemma_ in blacklist) for token in doc]):
+            all([(token.text in nltk_stopwords or token.lemma_ in nltk_stopwords or token.lemma_ in blacklist) for token
+                 in doc]):
         if debug:
             return False, doc.text
         return None  # ignore this concept as pattern
@@ -81,7 +80,6 @@ def create_matcher_patterns(cpnet_vocab_path, output_path, debug=False):
 
 
 def lemmatize(nlp, concept):
-
     doc = nlp(concept.replace("_", " "))
     lcs = set()
     # for i in range(len(doc)):
@@ -108,13 +106,15 @@ def load_matcher(nlp, pattern_path):
 
 
 def ground_qa_pair(qa_pair):
-    global nlp, matcher
-    if nlp is None or matcher is None:
+    s, a, pattern_path, cpnet_vocab_path = qa_pair
+
+    global nlp, matcher, CPNET_VOCAB
+    if nlp is None or matcher is None or CPNET_VOCAB is None:
         nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser', 'textcat'])
         nlp.add_pipe(nlp.create_pipe('sentencizer'))
-        matcher = load_matcher(nlp, PATTERN_PATH)
+        matcher = load_matcher(nlp, pattern_path)
+        CPNET_VOCAB = load_cpnet_vocab(cpnet_vocab_path)
 
-    s, a = qa_pair
     all_concepts = ground_mentioned_concepts(nlp, matcher, s, a)
     answer_concepts = ground_mentioned_concepts(nlp, matcher, a)
     question_concepts = all_concepts - answer_concepts
@@ -131,7 +131,6 @@ def ground_qa_pair(qa_pair):
 
 
 def ground_mentioned_concepts(nlp, matcher, s, ans=None):
-
     s = s.lower()
     doc = nlp(s)
     matches = matcher(doc)
@@ -233,10 +232,13 @@ def hard_ground(nlp, sent, cpnet_vocab):
     return res
 
 
-def match_mentioned_concepts(sents, answers, num_processes):
-    res = []
+def match_mentioned_concepts(sents, answers, num_processes, pattern_path, cpnet_vocab_path):
+    input_args = []
+    for s, a in zip(sents, answers):
+        input_args.append((s, a, pattern_path, cpnet_vocab_path))
+
     with Pool(num_processes) as p:
-        res = list(tqdm(p.imap(ground_qa_pair, zip(sents, answers)), total=len(sents)))
+        res = list(tqdm(p.imap(ground_qa_pair, input_args), total=len(sents)))
     return res
 
 
@@ -296,10 +298,10 @@ def prune(data, cpnet_vocab_path):
 
 
 def ground(statement_path, cpnet_vocab_path, pattern_path, output_path, num_processes=1, debug=False):
-    global PATTERN_PATH, CPNET_VOCAB
-    if PATTERN_PATH is None:
-        PATTERN_PATH = pattern_path
-        CPNET_VOCAB = load_cpnet_vocab(cpnet_vocab_path)
+    # global PATTERN_PATH, CPNET_VOCAB
+    # if PATTERN_PATH is None:
+    #     PATTERN_PATH = pattern_path
+    #     CPNET_VOCAB = load_cpnet_vocab(cpnet_vocab_path)
 
     sents = []
     answers = []
@@ -324,7 +326,7 @@ def ground(statement_path, cpnet_vocab_path, pattern_path, output_path, num_proc
                 print(ans)
                 answers.append(ans)
 
-    res = match_mentioned_concepts(sents, answers, num_processes)
+    res = match_mentioned_concepts(sents, answers, num_processes, pattern_path, cpnet_vocab_path)
     res = prune(res, cpnet_vocab_path)
 
     # check_path(output_path)
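
Why this refactor: `ground_qa_pair` previously read `PATTERN_PATH` from a module-level global that `ground()` assigned in the parent process. Workers inherit that assignment under the `fork` start method, but under `spawn` (the default on Windows, and on macOS since Python 3.8) each worker starts a fresh interpreter where `PATTERN_PATH` is still `None`. The patch instead ships `pattern_path` and `cpnet_vocab_path` with every task tuple so each worker lazily builds its own `nlp`, `matcher`, and `CPNET_VOCAB` on first use. A minimal sketch of that lazy per-worker initialization pattern; the names here (`_load_vocab`, `work`, `run_all`) are illustrative stand-ins, not functions from this file:

```python
# Sketch only: the init-once-per-worker pattern that ground_qa_pair /
# match_mentioned_concepts now follow. All names are illustrative.
from multiprocessing import Pool

_vocab = None  # per-process cache; starts as None in every worker


def _load_vocab(path):
    # Stand-in for load_cpnet_vocab(): one concept per line.
    with open(path, encoding='utf8') as fin:
        return set(line.strip() for line in fin)


def work(task):
    global _vocab
    word, vocab_path = task               # the path travels with each task ...
    if _vocab is None:                    # ... so the worker can self-initialize
        _vocab = _load_vocab(vocab_path)  # runs once per worker, not per task
    return word in _vocab


def run_all(words, vocab_path, num_processes=4):
    tasks = [(w, vocab_path) for w in words]
    with Pool(num_processes) as pool:
        return list(pool.imap(work, tasks))
```

Shipping the paths with each task makes the workers self-sufficient under any start method, at the cost of one extra load per worker process rather than one load total.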
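For reference, a hypothetical end-to-end invocation of the refactored entry point; every path below is a placeholder, not a file shipped with the repo:

```python
from utils.grounding import ground

ground(
    statement_path='data/statements.jsonl',          # placeholder paths
    cpnet_vocab_path='data/cpnet/concept.txt',
    pattern_path='data/cpnet/matcher_patterns.json',
    output_path='data/grounded.jsonl',
    num_processes=4,
)
```

With this change `pattern_path` and `cpnet_vocab_path` flow from `ground()` through `match_mentioned_concepts()` into each worker, so the commented-out global initialization left in `ground()` is dead code and could be deleted outright in a follow-up.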