add option to read in valid tagset from file (#9)
aflueckiger committed Jun 4, 2020
1 parent 0309a23 commit cd8eb04
Showing 4 changed files with 71 additions and 48 deletions.
22 changes: 18 additions & 4 deletions clef_evaluation.py
@@ -115,6 +115,10 @@ def parse_args():
help="Suffix to append at output file names",
)

+    parser.add_argument(
+        "--tagset", action="store", dest="f_tagset", help="file containing the valid tagset",
+    )

return parser.parse_args()
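With the new flag wired into parse_args, a tagset file is passed like any other option. A hypothetical invocation (the other flag names are assumptions for illustration, not taken from this diff):

# e.g.  python clef_evaluation.py --ref gold.tsv --pred system.tsv \
#           --task nerc_coarse --tagset tagset.txt
args = parse_args()
print(args.f_tagset)  # "tagset.txt", or None when --tagset is omitted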


@@ -142,13 +146,13 @@ def enforce_filename(fname):
return submission, lang


-def evaluation_wrapper(evaluator, eval_type, cols, n_best=1):
+def evaluation_wrapper(evaluator, eval_type, cols, n_best=1, tags=None):
eval_global = {}
eval_per_tag = {}

for col in cols:
eval_global[col], eval_per_tag[col] = evaluator.evaluate(
-            col, eval_type=eval_type, tags=None, merge_lines=True, n_best=n_best
+            col, eval_type=eval_type, tags=tags, merge_lines=True, n_best=n_best
)

# add aggregated stats across types as artificial tag
@@ -167,6 +171,7 @@ def get_results(
union=False,
outdir=".",
suffix="",
+    f_tagset=None,
):

if not skip_check:
@@ -188,14 +193,22 @@
else:
glueing_col_pairs = None

+    if f_tagset:
+        with open(f_tagset) as f_in:
+            tagset = set(f_in.read().upper().splitlines())
+    else:
+        tagset = None

evaluator = Evaluator(f_ref, f_pred, glueing_col_pairs)

if task == "nerc_fine":
eval_stats = evaluation_wrapper(evaluator, eval_type="nerc", cols=FINE_COLUMNS)
eval_stats = evaluation_wrapper(evaluator, eval_type="nerc", cols=FINE_COLUMNS, tags=tagset)
fieldnames, rows = assemble_tsv_output(submission, eval_stats)

elif task == "nerc_coarse":
eval_stats = evaluation_wrapper(evaluator, eval_type="nerc", cols=COARSE_COLUMNS)
eval_stats = evaluation_wrapper(
evaluator, eval_type="nerc", cols=COARSE_COLUMNS, tags=tagset
)
fieldnames, rows = assemble_tsv_output(submission, eval_stats)

elif task == "nel" and not union:
@@ -358,6 +371,7 @@ def main():
args.union,
args.outdir,
args.suffix,
+            args.f_tagset,
)
except AssertionError:
# don't interrupt the pipeline
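Taken together, the changes in clef_evaluation.py thread an optional tagset file from the command line down to the evaluator. A minimal sketch of the loading step as committed above (the sample file content is invented; note the upper-casing, which lets the lowercase entries of tagset.txt match the uppercase labels used at evaluation time):

def load_tagset(f_tagset):
    # as in get_results(): no path means no tag filtering
    if not f_tagset:
        return None
    with open(f_tagset) as f_in:
        # a file containing "loc\npers.ind\n" yields {"LOC", "PERS.IND"}
        return set(f_in.read().upper().splitlines())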
31 changes: 18 additions & 13 deletions ner_evaluation/ner_eval.py
@@ -17,7 +17,6 @@
collect_link_objects,
get_all_tags,
column_selector,
-    check_tag_selection,
)


@@ -162,7 +161,7 @@ def reconstruct_segmentation(self):
self.pred = docs_pred

def evaluate(
-        self, columns: list, eval_type: str, tags: list = None, merge_lines=False, n_best=1
+        self, columns: list, eval_type: str, tags: set = None, merge_lines=False, n_best=1
):
"""Collect extensive statistics across labels and per entity type.
@@ -174,7 +173,7 @@ def evaluate(
:param list columns: name of column that contains the annotations.
:param str eval_type: define evaluation type for either links (nel) or entities (nerc).
-        :param list tags: Description of parameter `tags`.
+        :param set tags: limit evaluation to valid tag set.
:param bool merge_lines: option to drop line segmentation to allow entity spans across lines.
:param int n_best: number of alternative links that should be considered.
        :return: Aggregated statistics across labels and per entity type.
@@ -188,7 +187,6 @@ def evaluate(
logging.info(f"Evaluating column {columns} in system response file '{self.f_pred}'")

tags = self.set_evaluation_tags(columns, tags, eval_type)
logging.info(f"Evaluation on the following tags: {tags}")

# Create an accumulator to store overall results
results = deepcopy(self.metric_schema)
@@ -553,12 +551,20 @@ def set_evaluation_tags(self, columns, tags, eval_type):
pred_tags = get_all_tags(y_pred)

if tags:
logging.info(f"Provided tags for the column {columns}: {tags}")
tags = check_tag_selection(y_true, tags)
logging.info(f"Evaluation is limited to the provided tag set: {tags}")
self.check_spurious_tags(tags, pred_tags, columns)

# take the union of the actual gold standard labels and
# labels of the response file that are valid even when not included
# in gold standard of this particular column
# Other spurious tags are treated as non-entity ('O' tag).

tags = true_tags | {tag for tag in pred_tags if tag in tags}

elif eval_type == "nerc":
# For NERC, only tags which are covered by the gold standard are considered
tags = true_tags
-            self.check_spurious_tags(y_true, y_pred, columns)
+            self.check_spurious_tags(true_tags, pred_tags, columns)

if not pred_tags:
msg = f"No tags in the column '{columns}' of the system response file: '{self.f_pred}'"
@@ -568,21 +574,20 @@
# For NEL, any tag in gold standard or predictions are considered
tags = true_tags | pred_tags

logging.info(f"Evaluating on the following tags: {tags}")

return tags

-    def check_spurious_tags(self, y_true: list, y_pred: list, columns: list):
+    def check_spurious_tags(self, tags_true: set, tags_pred: set, columns: list):
"""Log any tags of the system response which are not in the gold standard.
-        :param list y_true: a nested list of gold labels with the structure "[docs [sents [tokens]]]".
-        :param list y_pred: a nested list of system labels with the structure "[docs [sents [tokens]]]".
+        :param set tags_true: a set of gold standard labels.
+        :param set tags_pred: a set of system labels.
:return: None.
:rtype: None
"""

-        tags_true = get_all_tags(y_true)
-        tags_pred = get_all_tags(y_pred)

for pred in tags_pred:
if pred not in tags_true:
msg = f"Spurious entity label '{pred}' in column {columns} of system response file: '{self.f_pred}'. As the tag is not part of the gold standard, it is ignored in the evaluation."
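The heart of the change is the new selection rule in set_evaluation_tags: predicted tags that are in the valid tagset are kept even when the gold column never uses them, while anything outside the tagset is ignored. A toy walk-through of the set expression (all tag values are invented):

true_tags = {"LOC", "PERS"}         # labels occurring in the gold column
pred_tags = {"LOC", "ORG", "FOO"}   # labels occurring in the system response
tags = {"LOC", "PERS", "ORG"}       # valid tagset passed via --tagset

# "ORG" survives because it is valid, although absent from this gold column;
# "FOO" is outside the tagset and is treated as non-entity ('O')
tags = true_tags | {tag for tag in pred_tags if tag in tags}
assert tags == {"LOC", "PERS", "ORG"}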
34 changes: 3 additions & 31 deletions ner_evaluation/utils.py
@@ -63,44 +63,16 @@ def get_all_tags(y_true):
return tags


-def check_tag_selection(y_cand: list, tags_ref: list):
-    """Select only tags that are in the reference set and log dismissed tags.
-    :param list y_cand: a nested list of labels with the structure "[docs [sents [tokens]]]".
-    :param list tags_ref: a list of reference tags.
-    :return: a set with cleaned tags according to the reference
-    :rtype: set
-    """
-
-    tags_cand = get_all_tags(y_cand)
-
-    clean_tags = set()
-
-    for tag in tags_cand:
-        if tag not in tags_ref:
-            logging.info(
-                f"Selected tag '{tag}' is not covered by the gold data set and ignored in the evaluation."
-            )
-        else:
-            clean_tags.add(tag)
-
-    return clean_tags


-def check_spurious_tags(y_true: list, y_pred: list, columns: list):
+def check_spurious_tags(tags_true: set, tags_pred: set, columns: list):
"""Log any tags of the system response which are not in the gold standard.
-    :param list y_true: a nested list of gold labels with the structure "[docs [sents [tokens]]]".
-    :param list y_pred: a nested list of system labels with the structure "[docs [sents [tokens]]]".
+    :param set tags_true: a set of gold standard labels.
+    :param set tags_pred: a set of system labels.
:return: None.
:rtype: None
"""

-    tags_true = get_all_tags(y_true)
-    tags_pred = get_all_tags(y_pred)

for pred in tags_pred:
if pred not in tags_true:
msg = f"Spurious entity label '{pred}' in column {columns} in system response, which is part of the gold standard. Tag is ignored in the evaluation."
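After this refactor, both copies of check_spurious_tags compare plain tag sets instead of re-deriving them from nested label lists, so callers pass the sets they already hold. A hypothetical call (column name and tags are invented for the example):

# logs a warning for "FOO" only; "LOC" is covered by the gold standard
check_spurious_tags(
    tags_true={"LOC", "PERS"},
    tags_pred={"LOC", "FOO"},
    columns=["NE-COARSE-LIT"],
)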
32 changes: 32 additions & 0 deletions tagset.txt
@@ -0,0 +1,32 @@
comp.demonym
comp.function
comp.name
comp.qualifier
comp.title
loc
loc.add.elec
loc.add.phys
loc.adm.nat
loc.adm.reg
loc.adm.sup
loc.adm.town
loc.fac
loc.oro
loc.phys.astro
loc.phys.geo
loc.phys.hydro
loc.unk
org
org.adm
org.ent
org.ent.pressagency
pers
pers.coll
pers.ind
pers.ind.articleauthor
prod
prod.doctr
prod.media
time
time.date.abs
time.hour.abs
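The committed tagset follows the dotted coarse-to-fine hierarchy of the HIPE annotation scheme (loc, org, pers, prod, time, plus comp.* components). A small illustrative sketch that groups the fine-grained tags under their coarse prefix after loading (the grouping logic is an assumption, not part of the commit):

from collections import defaultdict

with open("tagset.txt") as f_in:
    tags = set(f_in.read().upper().splitlines())

by_coarse = defaultdict(set)
for tag in tags:
    by_coarse[tag.split(".")[0]].add(tag)

# e.g. by_coarse["LOC"] -> {"LOC", "LOC.ADM.TOWN", "LOC.PHYS.GEO", ...}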
