From cd8eb043d3d77c03d23fe75cddb61f086ee7eed7 Mon Sep 17 00:00:00 2001
From: Alex Flückiger
Date: Thu, 4 Jun 2020 19:44:08 +0200
Subject: [PATCH] add option to read in valid tagset from file (#9)

---
 clef_evaluation.py         | 22 ++++++++++++++++++----
 ner_evaluation/ner_eval.py | 31 ++++++++++++++++++-------------
 ner_evaluation/utils.py    | 34 +++-------------------------------
 tagset.txt                 | 32 ++++++++++++++++++++++++++++++++
 4 files changed, 71 insertions(+), 48 deletions(-)
 create mode 100644 tagset.txt

diff --git a/clef_evaluation.py b/clef_evaluation.py
index b220960..73e57c7 100644
--- a/clef_evaluation.py
+++ b/clef_evaluation.py
@@ -115,6 +115,10 @@ def parse_args():
         help="Suffix to append at output file names",
     )
 
+    parser.add_argument(
+        "--tagset", action="store", dest="f_tagset", help="File containing the valid tagset",
+    )
+
     return parser.parse_args()
 
 
@@ -142,13 +146,13 @@ def enforce_filename(fname):
     return submission, lang
 
 
-def evaluation_wrapper(evaluator, eval_type, cols, n_best=1):
+def evaluation_wrapper(evaluator, eval_type, cols, n_best=1, tags=None):
     eval_global = {}
     eval_per_tag = {}
 
     for col in cols:
         eval_global[col], eval_per_tag[col] = evaluator.evaluate(
-            col, eval_type=eval_type, tags=None, merge_lines=True, n_best=n_best
+            col, eval_type=eval_type, tags=tags, merge_lines=True, n_best=n_best
         )
 
     # add aggregated stats across types as artificial tag
@@ -167,6 +171,7 @@ def get_results(
     union=False,
     outdir=".",
     suffix="",
+    f_tagset=None,
 ):
 
     if not skip_check:
@@ -188,14 +193,22 @@ def get_results(
     else:
         glueing_col_pairs = None
 
+    if f_tagset:
+        with open(f_tagset) as f_in:
+            tagset = set(f_in.read().upper().splitlines())
+    else:
+        tagset = None
+
     evaluator = Evaluator(f_ref, f_pred, glueing_col_pairs)
 
     if task == "nerc_fine":
-        eval_stats = evaluation_wrapper(evaluator, eval_type="nerc", cols=FINE_COLUMNS)
+        eval_stats = evaluation_wrapper(evaluator, eval_type="nerc", cols=FINE_COLUMNS, tags=tagset)
         fieldnames, rows = assemble_tsv_output(submission, eval_stats)
 
     elif task == "nerc_coarse":
-        eval_stats = evaluation_wrapper(evaluator, eval_type="nerc", cols=COARSE_COLUMNS)
+        eval_stats = evaluation_wrapper(
+            evaluator, eval_type="nerc", cols=COARSE_COLUMNS, tags=tagset
+        )
         fieldnames, rows = assemble_tsv_output(submission, eval_stats)
 
     elif task == "nel" and not union:
@@ -358,6 +371,7 @@ def main():
             args.union,
             args.outdir,
             args.suffix,
+            args.f_tagset,
         )
     except AssertionError:
         # don't interrupt the pipeline
diff --git a/ner_evaluation/ner_eval.py b/ner_evaluation/ner_eval.py
index 0242cf5..5397b4b 100644
--- a/ner_evaluation/ner_eval.py
+++ b/ner_evaluation/ner_eval.py
@@ -17,7 +17,6 @@
     collect_link_objects,
     get_all_tags,
     column_selector,
-    check_tag_selection,
 )
 
 
@@ -162,7 +161,7 @@ def reconstruct_segmentation(self):
         self.pred = docs_pred
 
     def evaluate(
-        self, columns: list, eval_type: str, tags: list = None, merge_lines=False, n_best=1
+        self, columns: list, eval_type: str, tags: set = None, merge_lines=False, n_best=1
     ):
         """Collect extensive statistics across labels and per entity type.
 
@@ -174,7 +173,7 @@ def evaluate(
 
         :param list columns: name of column that contains the annotations.
         :param str eval_type: define evaluation type for either links (nel) or entities (nerc).
-        :param list tags: Description of parameter `tags`.
+        :param set tags: limit the evaluation to this set of valid tags.
         :param bool merge_lines: option to drop line segmentation to allow entity spans across lines.
         :param int n_best: number of alternative links that should be considered.
         :return: Aggregated statistics across labels and per entity type.
@@ -188,7 +187,6 @@ def evaluate(
         logging.info(f"Evaluating column {columns} in system response file '{self.f_pred}'")
 
         tags = self.set_evaluation_tags(columns, tags, eval_type)
-        logging.info(f"Evaluation on the following tags: {tags}")
 
         # Create an accumulator to store overall results
         results = deepcopy(self.metric_schema)
@@ -553,12 +551,20 @@ def set_evaluation_tags(self, columns, tags, eval_type):
         pred_tags = get_all_tags(y_pred)
 
         if tags:
-            logging.info(f"Provided tags for the column {columns}: {tags}")
-            tags = check_tag_selection(y_true, tags)
+            logging.info(f"Evaluation is limited to the provided tag set: {tags}")
+            self.check_spurious_tags(tags, pred_tags, columns)
+
+            # Take the union of the actual gold standard labels and those
+            # labels of the response file that are in the valid tag set,
+            # even when they are not part of this column's gold standard.
+            # All other spurious tags are treated as non-entities ('O' tag).
+
+            tags = true_tags | {tag for tag in pred_tags if tag in tags}
+
         elif eval_type == "nerc":
             # For NERC, only tags which are covered by the gold standard are considered
             tags = true_tags
-            self.check_spurious_tags(y_true, y_pred, columns)
+            self.check_spurious_tags(true_tags, pred_tags, columns)
 
         if not pred_tags:
             msg = f"No tags in the column '{columns}' of the system response file: '{self.f_pred}'"
@@ -568,22 +574,21 @@ def set_evaluation_tags(self, columns, tags, eval_type):
         else:
             # For NEL, any tag in gold standard or predictions are considered
             tags = true_tags | pred_tags
 
+        logging.info(f"Evaluating on the following tags: {tags}")
+
         return tags
 
-    def check_spurious_tags(self, y_true: list, y_pred: list, columns: list):
+    def check_spurious_tags(self, tags_true: set, tags_pred: set, columns: list):
         """Log any tags of the system response which are not in the gold standard.
 
-        :param list y_true: a nested list of gold labels with the structure "[docs [sents [tokens]]]".
-        :param list y_pred: a nested list of system labels with the structure "[docs [sents [tokens]]]".
+        :param set tags_true: a set of gold standard labels.
+        :param set tags_pred: a set of system response labels.
         :return: None.
         :rtype: None
 
         """
 
-        tags_true = get_all_tags(y_true)
-        tags_pred = get_all_tags(y_pred)
-
         for pred in tags_pred:
             if pred not in tags_true:
                 msg = f"Spurious entity label '{pred}' in column {columns} of system response file: '{self.f_pred}'. As the tag is not part of the gold standard, it is ignored in the evaluation."
diff --git a/ner_evaluation/utils.py b/ner_evaluation/utils.py
index 300bff4..1ac8cfa 100644
--- a/ner_evaluation/utils.py
+++ b/ner_evaluation/utils.py
@@ -63,44 +63,16 @@ def get_all_tags(y_true):
     return tags
 
 
-def check_tag_selection(y_cand: list, tags_ref: list):
-    """Select only tags that are in the reference set and log dismissed tags.
-
-    :param list y_cand: a nested list of labels with the structure "[docs [sents [tokens]]]".
-    :param list tags_ref: a list of of reference tags.
-    :return: a set with cleaned tags according to the reference
-    :rtype: set
-
-    """
-
-    tags_cand = get_all_tags(y_cand)
-
-    clean_tags = set()
-
-    for tag in tags_cand:
-        if tag not in tags_ref:
-            logging.info(
-                f"Selected tag '{tag}' is not covered by the gold data set and ignored for in the evaluation."
-            )
-        else:
-            clean_tags.add(tag)
-
-    return clean_tags
-
-
-def check_spurious_tags(y_true: list, y_pred: list, columns: list):
+def check_spurious_tags(tags_true: set, tags_pred: set, columns: list):
     """Log any tags of the system response which are not in the gold standard.
 
-    :param list y_true: a nested list of gold labels with the structure "[docs [sents [tokens]]]".
-    :param list y_pred: a nested list of system labels with the structure "[docs [sents [tokens]]]".
+    :param set tags_true: a set of gold standard labels.
+    :param set tags_pred: a set of system response labels.
    :return: None.
    :rtype: None

    """

-    tags_true = get_all_tags(y_true)
-    tags_pred = get_all_tags(y_pred)
-
     for pred in tags_pred:
         if pred not in tags_true:
             msg = f"Spurious entity label '{pred}' in column {columns} in system response, which is part of the gold standard. Tag is ignored in the evaluation."
diff --git a/tagset.txt b/tagset.txt
new file mode 100644
index 0000000..e811425
--- /dev/null
+++ b/tagset.txt
@@ -0,0 +1,32 @@
+comp.demonym
+comp.function
+comp.name
+comp.qualifier
+comp.title
+loc
+loc.add.elec
+loc.add.phys
+loc.adm.nat
+loc.adm.reg
+loc.adm.sup
+loc.adm.town
+loc.fac
+loc.oro
+loc.phys.astro
+loc.phys.geo
+loc.phys.hydro
+loc.unk
+org
+org.adm
+org.ent
+org.ent.pressagency
+pers
+pers.coll
+pers.ind
+pers.ind.articleauthor
+prod
+prod.doctr
+prod.media
+time
+time.date.abs
+time.hour.abs
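
Note for reviewers: a minimal, self-contained sketch of what the new option does
end to end. The toy tag values below are illustrative only (real HIPE labels use
the dotted scheme from tagset.txt); the two key statements mirror the code added
above in get_results and set_evaluation_tags.

    # Load the valid tag set the way get_results() now does: one tag per
    # line, upper-cased before comparison.
    with open("tagset.txt") as f_in:
        valid = set(f_in.read().upper().splitlines())

    # Tag selection as in set_evaluation_tags(): keep all gold standard
    # labels plus those response labels that are in the valid tag set,
    # even if they are absent from this column's gold standard.
    true_tags = {"LOC", "PERS"}            # labels found in the gold standard
    pred_tags = {"LOC", "ORG", "FOOBAR"}   # labels found in the system response
    tags = true_tags | {tag for tag in pred_tags if tag in valid}

    # 'FOOBAR' is spurious: it is logged and then treated as non-entity ('O').
    assert tags == {"LOC", "PERS", "ORG"}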