Commit

Fix PATTERN_PATH initialization on Windows
yanlinf committed Oct 23, 2021
1 parent 9e8167b commit e199dcb
Showing 1 changed file with 21 additions and 19 deletions.
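On Windows, Python's multiprocessing starts pool workers with the "spawn" method rather than "fork", so module-level globals assigned at runtime in the parent process (here PATTERN_PATH and CPNET_VOCAB, previously set inside ground()) are not visible to the workers, and ground_qa_pair would see them as None. The commit therefore passes pattern_path and cpnet_vocab_path through each task tuple and has every worker load the matcher and vocabulary itself. A minimal sketch of the underlying behavior, with hypothetical names and not part of the commit:

# Minimal sketch (hypothetical names): under the "spawn" start method,
# the default on Windows, a global assigned in the parent after import
# is not inherited by pool workers.
from multiprocessing import Pool, set_start_method

PATTERN_PATH = None  # module-level global, as in utils/grounding.py


def worker(_):
    # A spawned worker re-imports this module, so it sees the value from
    # import time (None), not the assignment made in the parent below.
    return PATTERN_PATH


if __name__ == "__main__":
    set_start_method("spawn")  # mimic the Windows default on any platform
    PATTERN_PATH = "./matcher_patterns.json"
    with Pool(2) as p:
        print(p.map(worker, [0, 1]))  # prints [None, None]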
40 changes: 21 additions & 19 deletions utils/grounding.py
@@ -6,18 +6,15 @@
import json
import string


__all__ = ['create_matcher_patterns', 'ground']


# the lemma of it/them/mine/.. is -PRON-

blacklist = set(["-PRON-", "actually", "likely", "possibly", "want",
                 "make", "my", "someone", "sometimes_people", "sometimes", "would", "want_to",
                 "one", "something", "sometimes", "everybody", "somebody", "could", "could_be"
                 ])


nltk.download('stopwords', quiet=True)
nltk_stopwords = nltk.corpus.stopwords.words('english')

@@ -37,10 +34,12 @@ def load_cpnet_vocab(cpnet_vocab_path):


def create_pattern(nlp, doc, debug=False):
-    pronoun_list = set(["my", "you", "it", "its", "your", "i", "he", "she", "his", "her", "they", "them", "their", "our", "we"])
+    pronoun_list = set(
+        ["my", "you", "it", "its", "your", "i", "he", "she", "his", "her", "they", "them", "their", "our", "we"])
    # Filtering concepts consisting of all stop words and longer than four words.
    if len(doc) >= 5 or doc[0].text in pronoun_list or doc[-1].text in pronoun_list or \
-            all([(token.text in nltk_stopwords or token.lemma_ in nltk_stopwords or token.lemma_ in blacklist) for token in doc]):
+            all([(token.text in nltk_stopwords or token.lemma_ in nltk_stopwords or token.lemma_ in blacklist) for token
+                 in doc]):
        if debug:
            return False, doc.text
        return None  # ignore this concept as pattern
@@ -81,7 +80,6 @@ def create_matcher_patterns(cpnet_vocab_path, output_path, debug=False):


def lemmatize(nlp, concept):
-
    doc = nlp(concept.replace("_", " "))
    lcs = set()
    # for i in range(len(doc)):
@@ -108,13 +106,15 @@ def load_matcher(nlp, pattern_path):


def ground_qa_pair(qa_pair):
-    global nlp, matcher
-    if nlp is None or matcher is None:
+    s, a, pattern_path, cpnet_vocab_path = qa_pair
+
+    global nlp, matcher, CPNET_VOCAB
+    if nlp is None or matcher is None or CPNET_VOCAB is None:
        nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser', 'textcat'])
        nlp.add_pipe(nlp.create_pipe('sentencizer'))
-        matcher = load_matcher(nlp, PATTERN_PATH)
+        matcher = load_matcher(nlp, pattern_path)
+        CPNET_VOCAB = load_cpnet_vocab(cpnet_vocab_path)

-    s, a = qa_pair
    all_concepts = ground_mentioned_concepts(nlp, matcher, s, a)
    answer_concepts = ground_mentioned_concepts(nlp, matcher, a)
    question_concepts = all_concepts - answer_concepts
@@ -131,7 +131,6 @@ def ground_qa_pair(qa_pair):


def ground_mentioned_concepts(nlp, matcher, s, ans=None):
-
    s = s.lower()
    doc = nlp(s)
    matches = matcher(doc)
@@ -233,10 +232,13 @@ def hard_ground(nlp, sent, cpnet_vocab):
    return res


-def match_mentioned_concepts(sents, answers, num_processes):
-    res = []
+def match_mentioned_concepts(sents, answers, num_processes, pattern_path, cpnet_vocab_path):
+    input_args = []
+    for s, a in zip(sents, answers):
+        input_args.append((s, a, pattern_path, cpnet_vocab_path))
+
    with Pool(num_processes) as p:
-        res = list(tqdm(p.imap(ground_qa_pair, zip(sents, answers)), total=len(sents)))
+        res = list(tqdm(p.imap(ground_qa_pair, input_args), total=len(sents)))
    return res


@@ -296,10 +298,10 @@ def prune(data, cpnet_vocab_path):


def ground(statement_path, cpnet_vocab_path, pattern_path, output_path, num_processes=1, debug=False):
-    global PATTERN_PATH, CPNET_VOCAB
-    if PATTERN_PATH is None:
-        PATTERN_PATH = pattern_path
-        CPNET_VOCAB = load_cpnet_vocab(cpnet_vocab_path)
+    # global PATTERN_PATH, CPNET_VOCAB
+    # if PATTERN_PATH is None:
+    #     PATTERN_PATH = pattern_path
+    #     CPNET_VOCAB = load_cpnet_vocab(cpnet_vocab_path)

sents = []
answers = []
@@ -324,7 +326,7 @@
print(ans)
answers.append(ans)

-    res = match_mentioned_concepts(sents, answers, num_processes)
+    res = match_mentioned_concepts(sents, answers, num_processes, pattern_path, cpnet_vocab_path)
res = prune(res, cpnet_vocab_path)

# check_path(output_path)
