diff --git a/mmpretrain/datasets/__init__.py b/mmpretrain/datasets/__init__.py
index b7b6be47dce..5bb5b230002 100644
--- a/mmpretrain/datasets/__init__.py
+++ b/mmpretrain/datasets/__init__.py
@@ -34,6 +34,7 @@
 ]
 
 if WITH_MULTIMODAL:
+    from .chartqa import ChartQA
     from .coco_caption import COCOCaption
     from .coco_retrieval import COCORetrieval
     from .coco_vqa import COCOVQA
@@ -54,5 +55,5 @@
         'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
         'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
         'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
-        'VSR', 'VizWiz', 'OCRVQA'
+        'VSR', 'VizWiz', 'OCRVQA', 'ChartQA'
     ])
diff --git a/mmpretrain/datasets/chartqa.py b/mmpretrain/datasets/chartqa.py
new file mode 100644
index 00000000000..180eaa78970
--- /dev/null
+++ b/mmpretrain/datasets/chartqa.py
@@ -0,0 +1,115 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import List, Union
+
+import mmengine
+from mmengine.dataset import BaseDataset
+from mmengine.utils import is_abs
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class ChartQA(BaseDataset):
+    """ChartQA dataset.
+
+    Dataset homepage: https://github.com/vis-nlp/ChartQA
+
+    Expected folder structure::
+
+        data/chartqa
+        ├── test
+        │   ├── png
+        │   ├── tables
+        │   ├── test_human.json
+        │   └── test_augmented.json
+        ├── train
+        │   ├── png
+        │   ├── tables
+        │   ├── train_human.json
+        │   └── train_augmented.json
+        └── val
+            ├── png
+            ├── tables
+            ├── val_human.json
+            └── val_augmented.json
+
+    Args:
+        data_root (str): The root directory for ``data_prefix`` and
+            ``ann_file``.
+        data_prefix (str): The directory of images, relative to ``data_root``.
+        ann_file (str | List[str]): Annotation file path(s). The ``*_human``
+            and ``*_augmented`` files may be passed together as a list.
+            Defaults to an empty string.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
+    """
+
+    def __init__(self,
+                 data_root: str,
+                 data_prefix: str,
+                 ann_file: Union[str, List[str]] = '',
+                 **kwargs):
+        super().__init__(
+            data_root=data_root,
+            data_prefix=dict(img_path=data_prefix),
+            ann_file=ann_file,
+            **kwargs,
+        )
+
+    def _join_prefix(self):
+        # A single annotation file may be passed as a plain string.
+        if isinstance(self.ann_file, str):
+            self.ann_file = [self.ann_file] if self.ann_file else []
+        # Automatically join each annotation file path with `self.data_root`
+        # if it is not an absolute path.
+        self.ann_file = [
+            sub_ann_file if is_abs(sub_ann_file) else osp.join(
+                self.data_root, sub_ann_file)
+            for sub_ann_file in self.ann_file
+        ]
+        # Automatically join the data directories with `self.data_root` if
+        # the path values in `self.data_prefix` are not absolute paths.
+        for data_key, prefix in self.data_prefix.items():
+            if isinstance(prefix, str):
+                if not is_abs(prefix):
+                    self.data_prefix[data_key] = osp.join(
+                        self.data_root, prefix)
+                else:
+                    self.data_prefix[data_key] = prefix
+            else:
+                raise TypeError('prefix should be a string, but got '
+                                f'{type(prefix)}')
+
+    def load_data_list(self) -> List[dict]:
+        """Load data list."""
+        data_list = []
+
+        for sub_ann_file in self.ann_file:
+
+            annotations = mmengine.load(sub_ann_file)
+
+            for ann in annotations:
+
+                # ann example
+                # {
+                #     'imgname': '41699051005347.png',
+                #     'query': 'How many food item i...bar graph?',
+                #     'label': '14'
+                # }
+
+                data_info = dict(question=ann['query'])
+                data_info['image_id'] = ann['imgname']
+
+                img_path = mmengine.join_path(self.data_prefix['img_path'],
+                                              ann['imgname'])
+
+                data_info['img_path'] = img_path
+                data_info['gt_answer'] = ann['label']
+
+                if 'human' in sub_ann_file:
+                    data_info['sub_set'] = 'ChartQA-H'
+                elif 'augmented' in sub_ann_file:
+                    data_info['sub_set'] = 'ChartQA-M'
+                else:
+                    raise ValueError(
+                        f'Unsupported annotation file {sub_ann_file}: '
+                        'cannot infer the subset (human/augmented).')
+
+                data_list.append(data_info)
+
+        return data_list
diff --git a/mmpretrain/evaluation/metrics/__init__.py b/mmpretrain/evaluation/metrics/__init__.py
index 7f5a4f36b41..e0dee70f761 100644
--- a/mmpretrain/evaluation/metrics/__init__.py
+++ b/mmpretrain/evaluation/metrics/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .caption import COCOCaption
+from .chartqa import ChartQARelaxACC
 from .gqa import GQAAcc
 from .multi_label import AveragePrecision, MultiLabelMetric
 from .multi_task import MultiTasksMetric
@@ -16,5 +17,5 @@
     'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
     'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption',
     'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave',
-    'RetrievalAveragePrecision'
+    'RetrievalAveragePrecision', 'ChartQARelaxACC'
 ]
diff --git a/mmpretrain/evaluation/metrics/chartqa.py b/mmpretrain/evaluation/metrics/chartqa.py
new file mode 100644
index 00000000000..c3294499b38
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/chartqa.py
@@ -0,0 +1,130 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import re
+from typing import List, Optional
+
+from mmengine.evaluator import BaseMetric
+
+from mmpretrain.registry import METRICS
+from .vqa import _process_digit_article, _process_punctuation
+
+
+@METRICS.register_module()
+class ChartQARelaxACC(BaseMetric):
+    """ChartQA relaxed accuracy metric.
+
+    A numeric prediction is counted as correct if it is within a relative
+    threshold (5% by default) of the ground-truth value; all other answers
+    must match the ground truth exactly after normalization.
+
+    Args:
+        relax_thresh (float): Relative tolerance applied to numeric answers.
+            Defaults to 0.05.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+    """
+    default_prefix = 'ChartQARelaxACC'
+
+    def __init__(self,
+                 relax_thresh: float = 0.05,
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None):
+        super().__init__(collect_device=collect_device, prefix=prefix)
+        self.relax_thresh = relax_thresh
+
+    def is_digit(self, x: str):
+        """Check whether ``x`` is a (possibly signed) integer or float."""
+        is_number = bool(re.match(r'^[+-]?\d+(\.\d+)?$', str(x)))
+        return is_number or str(x).isnumeric()
+
+    def process(self, data_batch, data_samples):
+        """Process one batch of data samples.
+
+        The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch: A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from the model.
+        """
+        for sample in data_samples:
+            gt_answer = sample.get('gt_answer')
+            sub_set = sample.get('sub_set')
+
+            is_digit = self.is_digit(gt_answer)
+
+            result = {
+                'pred_answer': sample.get('pred_answer'),
+                'gt_answer': gt_answer,
+                'is_digit': is_digit,
+                'sub_set': sub_set
+            }
+
+            self.results.append(result)
+
+    def compute_metrics(self, results: List):
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list): The processed results of each batch.
+
+        Returns:
+            Dict: The computed metrics. The keys are the names of the metrics,
+            and the values are corresponding results.
+        """
+        ChartQA_H_acc = []
+        ChartQA_M_acc = []
+        for result in results:
+            pred_answer = self._process_answer(result['pred_answer'])
+            gt_answer = result['gt_answer']
+            is_digit = result['is_digit']
+            sub_set = result['sub_set']
+
+            if is_digit:
+                if self.is_digit(pred_answer):
+                    pred_answer = float(pred_answer)
+                    gt_answer = float(gt_answer)
+                    # Relaxed accuracy: the prediction is correct if it falls
+                    # within `relax_thresh` (relative) of the ground truth.
+                    delta = abs(gt_answer) * self.relax_thresh
+                    lower_bound = gt_answer - delta
+                    upper_bound = gt_answer + delta
+                    chart_acc = float(
+                        lower_bound <= pred_answer <= upper_bound)
+                else:
+                    chart_acc = 0.0
+            else:
+                chart_acc = float(pred_answer == gt_answer)
+
+            if sub_set == 'ChartQA-H':
+                ChartQA_H_acc.append(chart_acc)
+            elif sub_set == 'ChartQA-M':
+                ChartQA_M_acc.append(chart_acc)
+            else:
+                raise ValueError(f'Unsupported subset {sub_set}.')
+
+        # Guard against an empty split to avoid division by zero when only
+        # one of the two annotation files is evaluated.
+        ChartQA_H_acc = sum(ChartQA_H_acc) / max(len(ChartQA_H_acc), 1) * 100
+        ChartQA_M_acc = sum(ChartQA_M_acc) / max(len(ChartQA_M_acc), 1) * 100
+
+        accuracy = (ChartQA_H_acc + ChartQA_M_acc) / 2
+
+        metrics = {
+            'ChartQA-H acc': ChartQA_H_acc,
+            'ChartQA-M acc': ChartQA_M_acc,
+            'Overall acc': accuracy
+        }
+
+        return metrics
+
+    def _process_answer(self, answer):
+        answer = answer.replace('\n', ' ')
+        answer = answer.replace('\t', ' ')
+        answer = answer.strip()
+        answer = _process_punctuation(answer)
+        answer = _process_digit_article(answer)
+        return answer
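
Below is a minimal, hypothetical test-time config sketch showing how the new dataset and metric could be wired together. Only the `ChartQA` and `ChartQARelaxACC` arguments come from this diff; the pipeline, batch size and local paths are placeholder assumptions.

```python
# Hypothetical evaluation config; paths and pipeline are placeholders.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(480, 480), interpolation='bicubic'),
    # Forward the fields the metric needs alongside the image.
    dict(type='PackInputs', algorithm_keys=['question', 'gt_answer', 'sub_set']),
]

test_dataloader = dict(
    batch_size=8,
    dataset=dict(
        type='ChartQA',
        data_root='data/chartqa/test',
        data_prefix='png',
        # Load both splits so the metric reports ChartQA-H, ChartQA-M and
        # the overall relaxed accuracy.
        ann_file=['test_human.json', 'test_augmented.json'],
        pipeline=test_pipeline,
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
)

test_evaluator = dict(type='ChartQARelaxACC', relax_thresh=0.05)
```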
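As a quick sanity check of the relaxed-accuracy rule (not part of the diff): with the default `relax_thresh` of 0.05, a ground truth of 14 accepts numeric predictions in [13.3, 14.7], while non-numeric answers must match exactly after normalization.

```python
# Sanity-check sketch; assumes the metric above is installed and importable.
from mmpretrain.evaluation.metrics import ChartQARelaxACC

metric = ChartQARelaxACC(relax_thresh=0.05)
samples = [
    # Within 5% of 14 -> counted as correct.
    dict(pred_answer='14.5', gt_answer='14', sub_set='ChartQA-H'),
    # Roughly 7% off -> counted as wrong.
    dict(pred_answer='15', gt_answer='14', sub_set='ChartQA-M'),
    # Non-numeric answers need an exact match after normalization.
    dict(pred_answer='yes', gt_answer='yes', sub_set='ChartQA-H'),
]
metric.process(data_batch=None, data_samples=samples)
print(metric.compute_metrics(metric.results))
# Expected: ChartQA-H acc 100.0, ChartQA-M acc 0.0, Overall acc 50.0
```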