From 53651fc6428c1006111bb3b58f7e0be554a7cbe6 Mon Sep 17 00:00:00 2001 From: Shreyanand Date: Fri, 4 Nov 2022 20:42:16 +0000 Subject: [PATCH] Add WIP transformers kpi traning notebook Signed-off-by: Shreyanand --- notebooks/demo2/transformer_kpi.ipynb | 1348 +++++++++++++++++++++++++ 1 file changed, 1348 insertions(+) create mode 100644 notebooks/demo2/transformer_kpi.ipynb diff --git a/notebooks/demo2/transformer_kpi.ipynb b/notebooks/demo2/transformer_kpi.ipynb new file mode 100644 index 0000000..b7afa8d --- /dev/null +++ b/notebooks/demo2/transformer_kpi.ipynb @@ -0,0 +1,1348 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ba2a359f-3b85-42b1-a0ce-3047904904e4", + "metadata": { + "tags": [] + }, + "source": [ + "## Hugging face kpi model\n", + "\n", + "This notebook first tries zero short learning with a bert model or in other words, direct prediction with a bert model on the climate relevance task. Then it fine tunes the bert model for the relevance task using the huggingface transformers package. " + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "id": "92113bf6-e6b7-41d1-8f5d-fa554ba05e97", + "metadata": { + "papermill": { + "duration": 3.30074, + "end_time": "2022-10-07T19:33:18.885511", + "exception": false, + "start_time": "2022-10-07T19:33:15.584771", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import config\n", + "import os\n", + "import pathlib\n", + "from dotenv import load_dotenv\n", + "from src.data.s3_communication import S3Communication\n", + "import numpy as np\n", + "from transformers import DistilBertTokenizerFast\n", + "from transformers import TrainingArguments, Trainer\n", + "from transformers import AutoModelForQuestionAnswering\n", + "from datasets import load_dataset, load_metric\n", + "from transformers import AutoTokenizer\n", + "import json\n", + "from torch import cuda\n", + "import collections\n", + "from tqdm.auto import tqdm\n", + "import torch\n", + "from transformers import default_data_collator\n", + "device = 'cuda' if cuda.is_available() else 'cpu'" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "03d92acc-d593-4aca-85d8-5cecde9559c0", + "metadata": { + "papermill": { + "duration": 0.011639, + "end_time": "2022-10-07T19:33:18.901426", + "exception": false, + "start_time": "2022-10-07T19:33:18.889787", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Load credentials\n", + "dotenv_dir = os.environ.get(\n", + " \"CREDENTIAL_DOTENV_DIR\", os.environ.get(\"PWD\", \"/opt/app-root/src\")\n", + ")\n", + "dotenv_path = pathlib.Path(dotenv_dir) / \"credentials.env\"\n", + "if os.path.exists(dotenv_path):\n", + " load_dotenv(dotenv_path=dotenv_path, override=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "77577e13-0810-4cf7-b6a3-dbae6e882ef3", + "metadata": { + "papermill": { + "duration": 0.090141, + "end_time": "2022-10-07T19:33:18.997018", + "exception": false, + "start_time": "2022-10-07T19:33:18.906877", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# init s3 connector\n", + "s3c = S3Communication(\n", + " s3_endpoint_url=os.getenv(\"S3_ENDPOINT\"),\n", + " aws_access_key_id=os.getenv(\"AWS_ACCESS_KEY_ID\"),\n", + " aws_secret_access_key=os.getenv(\"AWS_SECRET_ACCESS_KEY\"),\n", + " s3_bucket=os.getenv(\"S3_BUCKET\"),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "4cc05548-04ea-4f74-bcc5-4b23c5c88d9c", + "metadata": { + "papermill": { + "duration": 0.003808, + "end_time": "2022-10-07T19:33:19.004776", + "exception": false, + "start_time": "2022-10-07T19:33:19.000968", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Retrieve the test dataset and the trained models" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "f66f6932-e53f-470b-8a66-6d5f0905bb1c", + "metadata": { + "papermill": { + "duration": 0.886377, + "end_time": "2022-10-07T19:33:19.895045", + "exception": false, + "start_time": "2022-10-07T19:33:19.008668", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "s3c.download_files_in_prefix_to_dir(\n", + " config.BASE_TRAIN_TEST_DATASET_S3_PREFIX,\n", + " config.BASE_PROCESSED_DATA)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "34cd9699-eb97-4891-aefa-4165737e37fd", + "metadata": {}, + "outputs": [], + "source": [ + "test_data_path = str(config.BASE_PROCESSED_DATA)+'/kpi_test_split.json'\n", + "train_data_path = str(config.BASE_PROCESSED_DATA)+'/kpi_train_split.json'\n", + "\n", + "train_processed_data_path = str(config.BASE_PROCESSED_DATA)+'/kpi_processed_train_split.json'\n", + "test_processed_data_path = str(config.BASE_PROCESSED_DATA)+'/kpi_processed_test_split.json'" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "08584644-a10b-4d83-b600-538074489004", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def get_process_raw_data(data_path, output_path):\n", + " with open(data_path, \"r\") as read_file:\n", + " data = json.load(read_file)\n", + " squad = list()\n", + " for instance in data['data']:\n", + " instance_dic = dict()\n", + " title = instance['title']\n", + " for para in instance['paragraphs']:\n", + " context = para['context']\n", + " for ques in para['qas']:\n", + " question = ques['question']\n", + " qid = ques['id']\n", + " #is_impossible = ques['is_impossible']\n", + " answers = ques['answers']\n", + " instance_dic = {'id': qid,\n", + " 'title': title,\n", + " 'context': context,\n", + " 'question': question,\n", + " 'answers': answers}\n", + " squad.append(instance_dic)\n", + " with open(output_path, \"w\") as write_file:\n", + " json.dump(squad, write_file)\n", + "\n", + "\n", + "get_process_raw_data(train_data_path, train_processed_data_path)\n", + "get_process_raw_data(test_data_path, test_processed_data_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "c8246943-9d1e-4707-965c-b559c0f47c21", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using custom data configuration default-343a66958535ceab\n", + "Found cached dataset json (/opt/app-root/src/.cache/huggingface/datasets/json/default-343a66958535ceab/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "70edd5290e5540db8d3d4d0467ded1c2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00= end_char):\n", + " tokenized_examples[\"start_positions\"].append(cls_index)\n", + " tokenized_examples[\"end_positions\"].append(cls_index)\n", + " else:\n", + " # Otherwise move the token_start_index and token_end_index to the two ends of the answer.\n", + " # Note: we could go after the last offset if the answer is the last word (edge case).\n", + " while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:\n", + " token_start_index += 1\n", + " tokenized_examples[\"start_positions\"].append(token_start_index - 1)\n", + " while offsets[token_end_index][1] >= end_char:\n", + " token_end_index -= 1\n", + " tokenized_examples[\"end_positions\"].append(token_end_index + 1)\n", + "\n", + " return tokenized_examples" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "6089c758-33ed-4696-aa5f-d7839b13a2fc", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DatasetDict({\n", + " train: Dataset({\n", + " features: ['id', 'title', 'context', 'question', 'answers'],\n", + " num_rows: 66220\n", + " })\n", + " test: Dataset({\n", + " features: ['id', 'title', 'context', 'question', 'answers'],\n", + " num_rows: 16891\n", + " })\n", + "})" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "climate_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "f049487b-5455-4d7a-a5ab-25e1e09dc1fd", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "00e36586b5834eb7a63988f4e1fdd585", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/67 [00:00\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2b4a33136830488c8afaf59219b97ea2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/17 [00:00\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + } + ], + "source": [ + "tokenized_datasets = climate_dataset.map(prepare_train_features,\n", + " batched=True,\n", + " remove_columns=climate_dataset[\"train\"].column_names)" + ] + }, + { + "cell_type": "markdown", + "id": "280b1448-da31-44e8-bda0-a8891ac538cd", + "metadata": {}, + "source": [ + "## Train" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "5c11ad1b-094e-4406-8437-cdc5b8cd137c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_transform.weight', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias']\n", + "- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "interim_model_checkpoint = \"distilbert-base-uncased-finetuned-squad/checkpoint-4500\"\n", + "model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint).to(device)\n", + "\n", + "model_name = model_checkpoint.split(\"/\")[-1]\n", + "args = TrainingArguments(\n", + " f\"{model_name}-finetuned-squad\",\n", + " evaluation_strategy = \"no\",\n", + " learning_rate=2e-5,\n", + " per_device_train_batch_size=batch_size,\n", + " per_device_eval_batch_size=batch_size,\n", + " num_train_epochs=1,\n", + " weight_decay=0.01,\n", + " push_to_hub=False,\n", + ")\n", + "\n", + "data_collator = default_data_collator\n", + "\n", + "trainer = Trainer(\n", + " model,\n", + " args,\n", + " train_dataset=tokenized_datasets[\"train\"],\n", + " eval_dataset=None,\n", + " data_collator=data_collator,\n", + " tokenizer=tokenizer,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "7f2246ea-34c9-446e-a0b8-80db4c85754f", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/app-root/lib64/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", + " warnings.warn(\n", + "***** Running training *****\n", + " Num examples = 66449\n", + " Num Epochs = 1\n", + " Instantaneous batch size per device = 16\n", + " Total train batch size (w. parallel, distributed & accumulation) = 16\n", + " Gradient Accumulation steps = 1\n", + " Total optimization steps = 4154\n", + " Number of trainable parameters = 66364418\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [4154/4154 13:38, Epoch 1/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
5000.324900
10000.077500
15000.056400
20000.045500
25000.052600
30000.047300
35000.045700
40000.037500

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Saving model checkpoint to distilbert-base-uncased-finetuned-squad/checkpoint-500\n", + "Configuration saved in distilbert-base-uncased-finetuned-squad/checkpoint-500/config.json\n", + "Model weights saved in distilbert-base-uncased-finetuned-squad/checkpoint-500/pytorch_model.bin\n", + "tokenizer config file saved in distilbert-base-uncased-finetuned-squad/checkpoint-500/tokenizer_config.json\n", + "Special tokens file saved in distilbert-base-uncased-finetuned-squad/checkpoint-500/special_tokens_map.json\n", + "Saving model checkpoint to distilbert-base-uncased-finetuned-squad/checkpoint-1000\n", + "Configuration saved in distilbert-base-uncased-finetuned-squad/checkpoint-1000/config.json\n", + "Model weights saved in distilbert-base-uncased-finetuned-squad/checkpoint-1000/pytorch_model.bin\n", + "tokenizer config file saved in distilbert-base-uncased-finetuned-squad/checkpoint-1000/tokenizer_config.json\n", + "Special tokens file saved in distilbert-base-uncased-finetuned-squad/checkpoint-1000/special_tokens_map.json\n", + "Saving model checkpoint to distilbert-base-uncased-finetuned-squad/checkpoint-1500\n", + "Configuration saved in distilbert-base-uncased-finetuned-squad/checkpoint-1500/config.json\n", + "Model weights saved in distilbert-base-uncased-finetuned-squad/checkpoint-1500/pytorch_model.bin\n", + "tokenizer config file saved in distilbert-base-uncased-finetuned-squad/checkpoint-1500/tokenizer_config.json\n", + "Special tokens file saved in distilbert-base-uncased-finetuned-squad/checkpoint-1500/special_tokens_map.json\n", + "Saving model checkpoint to distilbert-base-uncased-finetuned-squad/checkpoint-2000\n", + "Configuration saved in distilbert-base-uncased-finetuned-squad/checkpoint-2000/config.json\n", + "Model weights saved in distilbert-base-uncased-finetuned-squad/checkpoint-2000/pytorch_model.bin\n", + "tokenizer config file saved in distilbert-base-uncased-finetuned-squad/checkpoint-2000/tokenizer_config.json\n", + "Special tokens file saved in distilbert-base-uncased-finetuned-squad/checkpoint-2000/special_tokens_map.json\n", + "Saving model checkpoint to distilbert-base-uncased-finetuned-squad/checkpoint-2500\n", + "Configuration saved in distilbert-base-uncased-finetuned-squad/checkpoint-2500/config.json\n", + "Model weights saved in distilbert-base-uncased-finetuned-squad/checkpoint-2500/pytorch_model.bin\n", + "tokenizer config file saved in distilbert-base-uncased-finetuned-squad/checkpoint-2500/tokenizer_config.json\n", + "Special tokens file saved in distilbert-base-uncased-finetuned-squad/checkpoint-2500/special_tokens_map.json\n", + "Saving model checkpoint to distilbert-base-uncased-finetuned-squad/checkpoint-3000\n", + "Configuration saved in distilbert-base-uncased-finetuned-squad/checkpoint-3000/config.json\n", + "Model weights saved in distilbert-base-uncased-finetuned-squad/checkpoint-3000/pytorch_model.bin\n", + "tokenizer config file saved in distilbert-base-uncased-finetuned-squad/checkpoint-3000/tokenizer_config.json\n", + "Special tokens file saved in distilbert-base-uncased-finetuned-squad/checkpoint-3000/special_tokens_map.json\n", + "Saving model checkpoint to distilbert-base-uncased-finetuned-squad/checkpoint-3500\n", + "Configuration saved in distilbert-base-uncased-finetuned-squad/checkpoint-3500/config.json\n", + "Model weights saved in distilbert-base-uncased-finetuned-squad/checkpoint-3500/pytorch_model.bin\n", + "tokenizer config file saved in distilbert-base-uncased-finetuned-squad/checkpoint-3500/tokenizer_config.json\n", + "Special tokens file saved in distilbert-base-uncased-finetuned-squad/checkpoint-3500/special_tokens_map.json\n", + "Saving model checkpoint to distilbert-base-uncased-finetuned-squad/checkpoint-4000\n", + "Configuration saved in distilbert-base-uncased-finetuned-squad/checkpoint-4000/config.json\n", + "Model weights saved in distilbert-base-uncased-finetuned-squad/checkpoint-4000/pytorch_model.bin\n", + "tokenizer config file saved in distilbert-base-uncased-finetuned-squad/checkpoint-4000/tokenizer_config.json\n", + "Special tokens file saved in distilbert-base-uncased-finetuned-squad/checkpoint-4000/special_tokens_map.json\n", + "\n", + "\n", + "Training completed. Do not forget to share your model on huggingface.co/models =)\n", + "\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "TrainOutput(global_step=4154, training_loss=0.08397273301045138, metrics={'train_runtime': 821.3564, 'train_samples_per_second': 80.902, 'train_steps_per_second': 5.057, 'total_flos': 6511325883019776.0, 'train_loss': 0.08397273301045138, 'epoch': 1.0})" + ] + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "id": "19a400b7-3c86-4c02-ad55-64324ac5c4d3", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Saving model checkpoint to /opt/app-root/src/aicoe-osc-demo/models/transformers/KPI_EXTRACTION\n", + "Configuration saved in /opt/app-root/src/aicoe-osc-demo/models/transformers/KPI_EXTRACTION/config.json\n", + "Model weights saved in /opt/app-root/src/aicoe-osc-demo/models/transformers/KPI_EXTRACTION/pytorch_model.bin\n", + "tokenizer config file saved in /opt/app-root/src/aicoe-osc-demo/models/transformers/KPI_EXTRACTION/tokenizer_config.json\n", + "Special tokens file saved in /opt/app-root/src/aicoe-osc-demo/models/transformers/KPI_EXTRACTION/special_tokens_map.json\n" + ] + } + ], + "source": [ + "local_model_path = '/opt/app-root/src/aicoe-osc-demo/models/transformers/KPI_EXTRACTION'\n", + "trainer.save_model(local_model_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "a373c28a-c9a1-428d-8c27-46e22394b2e2", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "loading file vocab.txt\n", + "loading file tokenizer.json\n", + "loading file added_tokens.json\n", + "loading file special_tokens_map.json\n", + "loading file tokenizer_config.json\n", + "loading configuration file /opt/app-root/src/aicoe-osc-demo/models/transformers/KPI_EXTRACTION/config.json\n", + "Model config DistilBertConfig {\n", + " \"_name_or_path\": \"/opt/app-root/src/aicoe-osc-demo/models/transformers/KPI_EXTRACTION\",\n", + " \"activation\": \"gelu\",\n", + " \"architectures\": [\n", + " \"DistilBertForQuestionAnswering\"\n", + " ],\n", + " \"attention_dropout\": 0.1,\n", + " \"dim\": 768,\n", + " \"dropout\": 0.1,\n", + " \"hidden_dim\": 3072,\n", + " \"initializer_range\": 0.02,\n", + " \"max_position_embeddings\": 512,\n", + " \"model_type\": \"distilbert\",\n", + " \"n_heads\": 12,\n", + " \"n_layers\": 6,\n", + " \"pad_token_id\": 0,\n", + " \"qa_dropout\": 0.1,\n", + " \"seq_classif_dropout\": 0.2,\n", + " \"sinusoidal_pos_embds\": false,\n", + " \"tie_weights_\": true,\n", + " \"torch_dtype\": \"float32\",\n", + " \"transformers_version\": \"4.24.0\",\n", + " \"vocab_size\": 30522\n", + "}\n", + "\n", + "loading weights file /opt/app-root/src/aicoe-osc-demo/models/transformers/KPI_EXTRACTION/pytorch_model.bin\n", + "All model checkpoint weights were used when initializing DistilBertForQuestionAnswering.\n", + "\n", + "All the weights of DistilBertForQuestionAnswering were initialized from the model checkpoint at /opt/app-root/src/aicoe-osc-demo/models/transformers/KPI_EXTRACTION.\n", + "If your task is similar to the task the model of the checkpoint was trained on, you can already use DistilBertForQuestionAnswering for predictions without further training.\n" + ] + } + ], + "source": [ + "local_model_path = '/opt/app-root/src/aicoe-osc-demo/models/transformers/KPI_EXTRACTION'\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(local_model_path, use_fast=True)\n", + "model = AutoModelForQuestionAnswering.from_pretrained(local_model_path).to(device)" + ] + }, + { + "cell_type": "markdown", + "id": "d7f11d4d-972a-4edf-b8fd-b672fc93195f", + "metadata": {}, + "source": [ + "## Evaluate" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "dd4d36fe-cd02-41ab-a686-3c894153d745", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['input_ids', 'attention_mask', 'start_positions', 'end_positions'],\n", + " num_rows: 16928\n", + "})" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenized_datasets['test']" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "07afa3a5-a1b2-469e-9fc0-e2f4dceb9689", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "odict_keys(['start_logits', 'end_logits'])" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "for batch in trainer.get_test_dataloader(tokenized_datasets['test'].remove_columns(['start_positions', 'end_positions'])):\n", + " break\n", + "batch = {k: v.to(trainer.args.device) for k, v in batch.items()}\n", + "with torch.no_grad():\n", + " output = model(**batch)\n", + "output.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "id": "403ba900-708f-45c4-bcf3-e6e2577e1180", + "metadata": {}, + "outputs": [], + "source": [ + "n_best_size = 20\n", + "start_logits = output.start_logits[0].cpu().numpy()\n", + "end_logits = output.end_logits[0].cpu().numpy()\n", + "# Gather the indices the best start/end logits:\n", + "start_indexes = np.argsort(start_logits)[-1: -n_best_size - 1 : -1].tolist()\n", + "end_indexes = np.argsort(end_logits)[-1: -n_best_size - 1 : -1].tolist()\n", + "valid_answers = []\n", + "for start_index in start_indexes:\n", + " for end_index in end_indexes:\n", + " if start_index <= end_index: # We need to refine that test to check the answer is inside the context\n", + " valid_answers.append(\n", + " {\n", + " \"score\": start_logits[start_index] + end_logits[end_index],\n", + " \"text\": \"\"# We need to find a way to get back the original\n", + " # substring corresponding to the answer in the context\n", + " }\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "9c052e73-70a0-45a2-984b-39826df0f834", + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_validation_features(examples):\n", + " # Some of the questions have lots of whitespace on the left, which is not useful and will make the\n", + " # truncation of the context fail (the tokenized question will take a lots of space). So we remove that\n", + " # left whitespace\n", + " examples[\"question\"] = [q.lstrip() for q in examples[\"question\"]]\n", + "\n", + " # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results\n", + " # in one example possible giving several features when a context is long, each of those features having a\n", + " # context that overlaps a bit the context of the previous feature.\n", + " tokenized_examples = tokenizer(\n", + " examples[\"question\" if pad_on_right else \"context\"],\n", + " examples[\"context\" if pad_on_right else \"question\"],\n", + " truncation=\"only_second\" if pad_on_right else \"only_first\",\n", + " max_length=max_length,\n", + " stride=doc_stride,\n", + " return_overflowing_tokens=True,\n", + " return_offsets_mapping=True,\n", + " padding=\"max_length\",\n", + " )\n", + "\n", + " # Since one example might give us several features if it has a long context, we need a map from a feature to\n", + " # its corresponding example. This key gives us just that.\n", + " sample_mapping = tokenized_examples.pop(\"overflow_to_sample_mapping\")\n", + "\n", + " # We keep the example_id that gave us this feature and we will store the offset mappings.\n", + " tokenized_examples[\"example_id\"] = []\n", + "\n", + " for i in range(len(tokenized_examples[\"input_ids\"])):\n", + " # Grab the sequence corresponding to that example (to know what is the context and what is the question).\n", + " sequence_ids = tokenized_examples.sequence_ids(i)\n", + " context_index = 1 if pad_on_right else 0\n", + "\n", + " # One example can give several spans, this is the index of the example containing this span of text.\n", + " sample_index = sample_mapping[i]\n", + " tokenized_examples[\"example_id\"].append(examples[\"id\"][sample_index])\n", + "\n", + " # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token\n", + " # position is part of the context or not.\n", + " tokenized_examples[\"offset_mapping\"][i] = [\n", + " (o if sequence_ids[k] == context_index else None)\n", + " for k, o in enumerate(tokenized_examples[\"offset_mapping\"][i])\n", + " ]\n", + "\n", + " return tokenized_examples" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "ed9d77ff-625f-4496-93be-2ff2cb198ece", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bcb05a2f1ef340a1ac69e750a11d43f2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/17 [00:00= len(offset_mapping)\n", + " or end_index >= len(offset_mapping)\n", + " or offset_mapping[start_index] is None\n", + " or offset_mapping[end_index] is None\n", + " ):\n", + " continue\n", + " # Don't consider answers with a length that is either < 0 or > max_answer_length.\n", + " if end_index < start_index or end_index - start_index + 1 > max_answer_length:\n", + " continue\n", + " if start_index <= end_index: # We need to refine that test to check the answer is inside the context\n", + " start_char = offset_mapping[start_index][0]\n", + " end_char = offset_mapping[end_index][1]\n", + " valid_answers.append(\n", + " {\n", + " \"score\": start_logits[start_index] + end_logits[end_index],\n", + " \"text\": context[start_char: end_char]\n", + " }\n", + " )\n", + "\n", + "valid_answers = sorted(valid_answers, key=lambda x: x[\"score\"], reverse=True)[:n_best_size]\n", + "valid_answers" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "e5044081-7551-48c3-90bb-1c1112be423d", + "metadata": {}, + "outputs": [], + "source": [ + "examples = climate_dataset[\"test\"]\n", + "features = validation_features\n", + "\n", + "example_id_to_index = {k: i for i, k in enumerate(examples[\"id\"])}\n", + "features_per_example = collections.defaultdict(list)\n", + "for i, feature in enumerate(features):\n", + " features_per_example[example_id_to_index[feature[\"example_id\"]]].append(i)" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "ce38eda3-717d-430c-ab15-7b457598a19f", + "metadata": {}, + "outputs": [], + "source": [ + "def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size = 20, max_answer_length = 30):\n", + " all_start_logits, all_end_logits = raw_predictions\n", + " # Build a map example to its corresponding features.\n", + " example_id_to_index = {k: i for i, k in enumerate(examples[\"id\"])}\n", + " features_per_example = collections.defaultdict(list)\n", + " for i, feature in enumerate(features):\n", + " features_per_example[example_id_to_index[feature[\"example_id\"]]].append(i)\n", + "\n", + " # The dictionaries we have to fill.\n", + " predictions = collections.OrderedDict()\n", + "\n", + " # Logging.\n", + " print(f\"Post-processing {len(examples)} example predictions split into {len(features)} features.\")\n", + "\n", + " # Let's loop over all the examples!\n", + " for example_index, example in enumerate(tqdm(examples)):\n", + " # Those are the indices of the features associated to the current example.\n", + " feature_indices = features_per_example[example_index]\n", + "\n", + " min_null_score = None # Only used if squad_v2 is True.\n", + " valid_answers = []\n", + "\n", + " context = example[\"context\"]\n", + " # Looping through all the features associated to the current example.\n", + " for feature_index in feature_indices:\n", + " # We grab the predictions of the model for this feature.\n", + " start_logits = all_start_logits[feature_index]\n", + " end_logits = all_end_logits[feature_index]\n", + " # This is what will allow us to map some the positions in our logits to span of texts in the original\n", + " # context.\n", + " offset_mapping = features[feature_index][\"offset_mapping\"]\n", + "\n", + " # Update minimum null prediction.\n", + " cls_index = features[feature_index][\"input_ids\"].index(tokenizer.cls_token_id)\n", + " feature_null_score = start_logits[cls_index] + end_logits[cls_index]\n", + " if min_null_score is None or min_null_score < feature_null_score:\n", + " min_null_score = feature_null_score\n", + "\n", + " # Go through all possibilities for the `n_best_size` greater start and end logits.\n", + " start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()\n", + " end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()\n", + " for start_index in start_indexes:\n", + " for end_index in end_indexes:\n", + " # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond\n", + " # to part of the input_ids that are not in the context.\n", + " if (\n", + " start_index >= len(offset_mapping)\n", + " or end_index >= len(offset_mapping)\n", + " or offset_mapping[start_index] is None\n", + " or offset_mapping[end_index] is None\n", + " ):\n", + " continue\n", + " # Don't consider answers with a length that is either < 0 or > max_answer_length.\n", + " if end_index < start_index or end_index - start_index + 1 > max_answer_length:\n", + " continue\n", + "\n", + " start_char = offset_mapping[start_index][0]\n", + " end_char = offset_mapping[end_index][1]\n", + " valid_answers.append(\n", + " {\n", + " \"score\": start_logits[start_index] + end_logits[end_index],\n", + " \"text\": context[start_char: end_char]\n", + " }\n", + " )\n", + "\n", + " if len(valid_answers) > 0:\n", + " best_answer = sorted(valid_answers, key=lambda x: x[\"score\"], reverse=True)[0]\n", + " else:\n", + " # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid\n", + " # failure.\n", + " best_answer = {\"text\": \"\", \"score\": 0.0}\n", + "\n", + " # Let's pick our final answer: the best one or the null answer (only for squad_v2)\n", + " answer = best_answer[\"text\"] if best_answer[\"score\"] > min_null_score else \"\"\n", + " predictions[example[\"id\"]] = answer\n", + "\n", + " return predictions" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "id": "74f9620d-bccc-4476-9d7d-bf9592925797", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Post-processing 16891 example predictions split into 16928 features.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b68483eb50994b2ba99fca6c349dbdb1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/16891 [00:00