From 66ebecd12e689ffccc5b050560bd370c1bcc125d Mon Sep 17 00:00:00 2001
From: Andrey Velichkevich
Date: Tue, 22 Oct 2024 13:18:28 +0100
Subject: [PATCH] FSDP Example for T5 Fine-Tuning and PyTorchJob (#2286)

* FSDP Example with PyTorchJob and T5 Fine-Tuning

Signed-off-by: Andrey Velichkevich

* Modify text

Signed-off-by: Andrey Velichkevich

---------

Signed-off-by: Andrey Velichkevich
---
 .../pytorch/fsdp/fine-tune-t5-with-fsdp.ipynb | 509 ++++++++++++++++++
 1 file changed, 509 insertions(+)
 create mode 100644 examples/pytorch/fsdp/fine-tune-t5-with-fsdp.ipynb

diff --git a/examples/pytorch/fsdp/fine-tune-t5-with-fsdp.ipynb b/examples/pytorch/fsdp/fine-tune-t5-with-fsdp.ipynb
new file mode 100644
index 0000000000..cc1fab15bc
--- /dev/null
+++ b/examples/pytorch/fsdp/fine-tune-t5-with-fsdp.ipynb
@@ -0,0 +1,509 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Fine-Tune T5 Model with PyTorchJob and FSDP"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This Notebook will fine-tune the Text-to-Text Transfer Transformer (T5) model on the Wikihow dataset for text summarization using Kubeflow PyTorchJob.\n",
+    "\n",
+    "Pretrained T5 model: https://huggingface.co/google-t5/t5-base\n",
+    "\n",
+    "Wikihow dataset: https://github.com/mahnazkoupaee/WikiHow-Dataset\n",
+    "\n",
+    "This Notebook will use **4** GPUs across **2** nodes to fine-tune the T5 model. This example is based on [the official PyTorch FSDP tutorial](https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## FSDP with multi-node multi-worker training"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This Notebook demonstrates multi-node, multi-worker distributed training with Fully Sharded Data Parallel (FSDP) and PyTorchJob.\n",
+    "\n",
+    "When a model is trained with FSDP, the GPU memory footprint is smaller compared to Distributed Data Parallel (DDP),\n",
+    "as the model parameters are sharded across GPU devices.\n",
+    "\n",
+    "This enables training of very large models that would otherwise not fit on a single GPU device.\n",
+    "\n",
+    "Check this guide to learn more about PyTorch FSDP: https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "tags": []
+   },
+   "source": [
+    "## Install the required packages\n",
+    "\n",
+    "Install the Kubeflow Training Python SDK."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# TODO (andreyvelich): Use the release version of SDK.\n",
+    "!pip install git+https://github.com/kubeflow/training-operator.git#subdirectory=sdk/python"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Create a script to fine-tune T5 using FSDP\n",
+    "\n",
+    "We need to wrap our fine-tuning script in a function to create the Kubeflow PyTorchJob.\n",
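+    "\n",
+    "As a minimal sketch of the required shape (the complete function is defined in the next cell): the Training SDK packages the function body and runs it on every worker, so all imports must live inside it, and it must accept a single `parameters` dict. Each worker process receives the `RANK`, `WORLD_SIZE`, `LOCAL_RANK`, `MASTER_ADDR`, and `MASTER_PORT` environment variables, so `dist.init_process_group` needs no explicit arguments:\n",
+    "\n",
+    "```python\n",
+    "def train_function(parameters):\n",
+    "    # Imports must be inside the function, since only its body runs on the workers.\n",
+    "    import os\n",
+    "\n",
+    "    import torch\n",
+    "    import torch.distributed as dist\n",
+    "\n",
+    "    # The rendezvous environment variables are already set for every worker.\n",
+    "    dist.init_process_group(\"nccl\")\n",
+    "    torch.cuda.set_device(int(os.environ[\"LOCAL_RANK\"]))\n",
+    "\n",
+    "    ...  # build the dataset, wrap the model with FSDP, and run the training loop\n",
+    "```"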
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "def train_function(parameters):\n",
+    "    import os\n",
+    "    import time\n",
+    "    import functools\n",
+    "\n",
+    "    import torch\n",
+    "    import torch.distributed as dist\n",
+    "    from torch.utils.data.distributed import DistributedSampler\n",
+    "    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n",
+    "    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy\n",
+    "\n",
+    "    from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
+    "    from transformers.models.t5.modeling_t5 import T5Block\n",
+    "    from datasets import Dataset\n",
+    "\n",
+    "    # [1] Setup PyTorch distributed and get the distributed parameters.\n",
+    "    dist.init_process_group(\"nccl\")\n",
+    "    local_rank = int(os.environ[\"LOCAL_RANK\"])\n",
+    "    rank = dist.get_rank()\n",
+    "    world_size = dist.get_world_size()\n",
+    "\n",
+    "    # Local rank identifies the GPU number inside the pod.\n",
+    "    torch.cuda.set_device(local_rank)\n",
+    "\n",
+    "    print(\n",
+    "        f\"FSDP Training for WORLD_SIZE: {world_size}, RANK: {rank}, LOCAL_RANK: {local_rank}\"\n",
+    "    )\n",
+    "\n",
+    "    # [2] Prepare the Wikihow dataset\n",
+    "    class wikihow(torch.utils.data.Dataset):\n",
+    "        def __init__(\n",
+    "            self,\n",
+    "            tokenizer,\n",
+    "            num_samples,\n",
+    "            input_length,\n",
+    "            output_length,\n",
+    "        ):\n",
+    "\n",
+    "            self.dataset = Dataset.from_csv(parameters[\"DATASET_URL\"])\n",
+    "            self.dataset = self.dataset.select(list(range(0, num_samples)))\n",
+    "            self.input_length = input_length\n",
+    "            self.tokenizer = tokenizer\n",
+    "            self.output_length = output_length\n",
+    "\n",
+    "        def __len__(self):\n",
+    "            return self.dataset.shape[0]\n",
+    "\n",
+    "        def clean_text(self, text):\n",
+    "            # Dataset contains empty values.\n",
+    "            if text is None:\n",
+    "                return \"\"\n",
+    "            text = text.replace(\"Example of text:\", \"\")\n",
+    "            text = text.replace(\"Example of Summary:\", \"\")\n",
+    "            text = text.replace(\"\\n\", \"\")\n",
+    "            text = text.replace(\"``\", \"\")\n",
+    "            text = text.replace('\"', \"\")\n",
+    "\n",
+    "            return text\n",
+    "\n",
+    "        def convert_to_features(self, example_batch):\n",
+    "            # Tokenize text and headline (as pairs of inputs).\n",
+    "            input_ = self.clean_text(example_batch[\"text\"])\n",
+    "            target_ = self.clean_text(example_batch[\"headline\"])\n",
+    "\n",
+    "            source = self.tokenizer.batch_encode_plus(\n",
+    "                [input_],\n",
+    "                max_length=self.input_length,\n",
+    "                padding=\"max_length\",\n",
+    "                truncation=True,\n",
+    "                return_tensors=\"pt\",\n",
+    "            )\n",
+    "\n",
+    "            targets = self.tokenizer.batch_encode_plus(\n",
+    "                [target_],\n",
+    "                max_length=self.output_length,\n",
+    "                padding=\"max_length\",\n",
+    "                truncation=True,\n",
+    "                return_tensors=\"pt\",\n",
+    "            )\n",
+    "\n",
+    "            return source, targets\n",
+    "\n",
+    "        def __getitem__(self, index):\n",
+    "            source, targets = self.convert_to_features(self.dataset[index])\n",
+    "\n",
+    "            source_ids = source[\"input_ids\"].squeeze()\n",
+    "            target_ids = targets[\"input_ids\"].squeeze()\n",
+    "\n",
+    "            src_mask = source[\"attention_mask\"].squeeze()\n",
+    "            target_mask = targets[\"attention_mask\"].squeeze()\n",
+    "\n",
+    "            return {\n",
+    "                \"source_ids\": source_ids,\n",
+    "                \"source_mask\": src_mask,\n",
+    "                \"target_ids\": target_ids,\n",
+    "                \"target_mask\": target_mask,\n",
+    "            }\n",
+    "\n",
+    "    # [3] Get the T5 pre-trained model and tokenizer.\n",
+    "    # Since this script is run by multiple workers, we should print results only for the worker with RANK=0.\n",
+    "    if rank == 0:\n",
+    "        print(f\"Downloading the {parameters['MODEL_NAME']} model\")\n",
+    "\n",
+    "    model = T5ForConditionalGeneration.from_pretrained(parameters[\"MODEL_NAME\"])\n",
+    "    tokenizer = T5Tokenizer.from_pretrained(parameters[\"MODEL_NAME\"])\n",
+    "\n",
+    "    # [4] Download the Wikihow dataset.\n",
+    "    if rank == 0:\n",
+    "        print(\"Downloading the Wikihow dataset\")\n",
+    "\n",
+    "    dataset = wikihow(tokenizer, 1500, 512, 150)\n",
+    "    train_loader = torch.utils.data.DataLoader(\n",
+    "        dataset,\n",
+    "        batch_size=4,\n",
+    "        sampler=DistributedSampler(dataset),\n",
+    "    )\n",
+    "\n",
+    "    # [5] Setup model with FSDP.\n",
+    "    # Model is on CPU before input to FSDP.\n",
+    "    t5_auto_wrap_policy = functools.partial(\n",
+    "        transformer_auto_wrap_policy,\n",
+    "        transformer_layer_cls={\n",
+    "            T5Block,\n",
+    "        },\n",
+    "    )\n",
+    "    model = FSDP(\n",
+    "        model,\n",
+    "        auto_wrap_policy=t5_auto_wrap_policy,\n",
+    "        device_id=torch.cuda.current_device(),\n",
+    "    )\n",
+    "\n",
+    "    # [6] Start training.\n",
+    "    optimizer = torch.optim.AdamW(model.parameters(), lr=0.002)\n",
+    "    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)\n",
+    "    t0 = time.time()\n",
+    "    if rank == 0:\n",
+    "        print(\"Training is started...\")\n",
+    "\n",
+    "    for epoch in range(1, 3):\n",
+    "        # Re-shuffle the DistributedSampler shards, so every epoch sees a different data order.\n",
+    "        train_loader.sampler.set_epoch(epoch)\n",
+    "        model.train()\n",
+    "        fsdp_loss = torch.zeros(2).to(local_rank)\n",
+    "\n",
+    "        for batch in train_loader:\n",
+    "            for key in batch.keys():\n",
+    "                batch[key] = batch[key].to(local_rank)\n",
+    "\n",
+    "            optimizer.zero_grad()\n",
+    "\n",
+    "            output = model(\n",
+    "                input_ids=batch[\"source_ids\"],\n",
+    "                attention_mask=batch[\"source_mask\"],\n",
+    "                labels=batch[\"target_ids\"],\n",
+    "            )\n",
+    "            loss = output[\"loss\"]\n",
+    "            loss.backward()\n",
+    "            optimizer.step()\n",
+    "            fsdp_loss[0] += loss.item()\n",
+    "            # Count samples in the batch. len(batch) would count the dict keys instead.\n",
+    "            fsdp_loss[1] += batch[\"source_ids\"].size(0)\n",
+    "\n",
+    "        dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)\n",
+    "        train_loss = fsdp_loss[0] / fsdp_loss[1]\n",
+    "\n",
+    "        if rank == 0:\n",
+    "            print(f\"Train Epoch: \\t{epoch}, Loss: \\t{train_loss:.4f}\")\n",
+    "\n",
+    "        scheduler.step()\n",
+    "\n",
+    "    dist.barrier()\n",
+    "\n",
+    "    if rank == 0:\n",
+    "        print(f\"FSDP training time: {int(time.time() - t0)} seconds\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Create Kubeflow PyTorchJob to fine-tune T5 with FSDP\n",
+    "\n",
+    "Use `TrainingClient()` to create a PyTorchJob which will fine-tune T5 on **2 workers** using **2 GPUs** for each worker.\n",
+    "\n",
+    "If you don't have enough GPU resources, you can decrease the number of workers or the number of GPUs per worker.\n",
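+    "\n",
+    "The resulting distributed world size is `num_workers * num_procs_per_worker = 2 * 2 = 4`, which matches the `WORLD_SIZE: 4` printed in the training logs below."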
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "from kubeflow.training import TrainingClient\n",
+    "\n",
+    "job_name = \"fsdp-fine-tuning\"\n",
+    "\n",
+    "parameters = {\n",
+    "    \"DATASET_URL\": \"https://public-nlp-datasets.s3.us-west-2.amazonaws.com/wikihowAll.csv\",\n",
+    "    \"MODEL_NAME\": \"t5-base\",\n",
+    "}\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# Create the PyTorchJob.\n",
+    "TrainingClient().create_job(\n",
+    "    name=job_name,\n",
+    "    train_func=train_function,\n",
+    "    parameters=parameters,\n",
+    "    num_workers=2,  # You can modify the number of workers or the number of GPUs per worker.\n",
+    "    num_procs_per_worker=2,\n",
+    "    resources_per_worker={\"gpu\": 2},\n",
+    "    packages_to_install=[\n",
+    "        \"transformers==4.38.2\",\n",
+    "        \"datasets==2.21.0\",\n",
+    "        \"SentencePiece==0.2.0\",\n",
+    "    ],\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "tags": []
+   },
+   "source": [
+    "### Check the PyTorchJob conditions\n",
+    "\n",
+    "Use the `TrainingClient()` APIs to get information about the created PyTorchJob."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "PyTorchJob Conditions\n",
+      "[{'last_transition_time': datetime.datetime(2024, 10, 16, 19, 24, 17, tzinfo=tzutc()),\n",
+      "  'last_update_time': datetime.datetime(2024, 10, 16, 19, 24, 17, tzinfo=tzutc()),\n",
+      "  'message': 'PyTorchJob fsdp-fine-tuning is created.',\n",
+      "  'reason': 'PyTorchJobCreated',\n",
+      "  'status': 'True',\n",
+      "  'type': 'Created'}, {'last_transition_time': datetime.datetime(2024, 10, 16, 19, 24, 18, tzinfo=tzutc()),\n",
+      "  'last_update_time': datetime.datetime(2024, 10, 16, 19, 24, 18, tzinfo=tzutc()),\n",
+      "  'message': 'PyTorchJob fsdp-fine-tuning is running.',\n",
+      "  'reason': 'PyTorchJobRunning',\n",
+      "  'status': 'True',\n",
+      "  'type': 'Running'}]\n",
+      "----------------------------------------\n",
+      "PyTorchJob is running\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"PyTorchJob Conditions\")\n",
+    "print(TrainingClient().get_job_conditions(job_name))\n",
+    "print(\"-\" * 40)\n",
+    "\n",
+    "# Wait until PyTorchJob has the Running condition.\n",
+    "job = TrainingClient().wait_for_job_conditions(\n",
+    "    job_name,\n",
+    "    expected_conditions={\"Running\"},\n",
+    ")\n",
+    "print(\"PyTorchJob is running\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Get the PyTorchJob pod names\n",
+    "\n",
+    "Since we define 2 workers, PyTorchJob will create 1 master pod and 1 worker pod to run FSDP fine-tuning.\n",
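+    "\n",
+    "With 2 processes per worker, torchrun assigns global ranks pod by pod (`rank = node_rank * nproc_per_node + local_rank`), so the master pod hosts ranks 0 and 1, and the worker pod hosts ranks 2 and 3."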
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['fsdp-fine-tuning-master-0', 'fsdp-fine-tuning-worker-0']"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "TrainingClient().get_job_pod_names(job_name)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "execution": {
+     "iopub.status.busy": "2022-09-01T20:10:25.759950Z",
+     "iopub.status.idle": "2022-09-01T20:10:25.760581Z",
+     "shell.execute_reply": "2022-09-01T20:10:25.760353Z",
+     "shell.execute_reply.started": "2022-09-01T20:10:25.760328Z"
+    },
+    "tags": []
+   },
+   "source": [
+    "### Get the PyTorchJob training logs\n",
+    "\n",
+    "Model parameters are sharded across all workers and GPU devices."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[Pod fsdp-fine-tuning-master-0]: WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\n",
+      "[Pod fsdp-fine-tuning-master-0]: [2024-10-16 19:24:47,178] torch.distributed.run: [WARNING] \n",
+      "[Pod fsdp-fine-tuning-master-0]: [2024-10-16 19:24:47,178] torch.distributed.run: [WARNING] *****************************************\n",
+      "[Pod fsdp-fine-tuning-master-0]: [2024-10-16 19:24:47,178] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n",
+      "[Pod fsdp-fine-tuning-master-0]: [2024-10-16 19:24:47,178] torch.distributed.run: [WARNING] *****************************************\n",
+      "[Pod fsdp-fine-tuning-master-0]: FSDP Training for WORLD_SIZE: 4, RANK: 0, LOCAL_RANK: 0\n",
+      "[Pod fsdp-fine-tuning-master-0]: Downloading the t5-base model\n",
+      "[Pod fsdp-fine-tuning-master-0]: /opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
+      "[Pod fsdp-fine-tuning-master-0]:   warnings.warn(\n",
+      "[Pod fsdp-fine-tuning-master-0]: FSDP Training for WORLD_SIZE: 4, RANK: 1, LOCAL_RANK: 1\n",
+      "[Pod fsdp-fine-tuning-master-0]: /opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
+      "[Pod fsdp-fine-tuning-master-0]:   warnings.warn(\n",
+      "[Pod fsdp-fine-tuning-master-0]: You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n",
+      "[Pod fsdp-fine-tuning-master-0]: You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n",
+      "[Pod fsdp-fine-tuning-master-0]: Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
+      "[Pod fsdp-fine-tuning-master-0]: Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
+      "[Pod fsdp-fine-tuning-master-0]: Downloading the Wikihow dataset\n",
+      "Downloading data: 100%|██████████| 619M/619M [00:11<00:00, 55.0MB/s] \n",
+      "Generating train split: 215365 examples [00:08, 26087.53 examples/s]\n",
+      "[Pod fsdp-fine-tuning-master-0]: Training is started...\n",
+      "[Pod fsdp-fine-tuning-master-0]: Train Epoch: \t1, Loss: \t0.3802\n",
+      "[Pod fsdp-fine-tuning-master-0]: Train Epoch: \t2, Loss: \t0.2659\n",
+      "[Pod fsdp-fine-tuning-master-0]: FSDP training time: 107 seconds\n"
+     ]
+    }
+   ],
+   "source": [
+    "logs, _ = TrainingClient().get_job_logs(job_name, follow=True)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2024-03-01T23:44:15.511173Z",
+     "iopub.status.busy": "2024-03-01T23:44:15.510932Z",
+     "iopub.status.idle": "2024-03-01T23:44:15.539921Z",
+     "shell.execute_reply": "2024-03-01T23:44:15.539352Z",
+     "shell.execute_reply.started": "2024-03-01T23:44:15.511155Z"
+    },
+    "tags": []
+   },
+   "source": [
+    "## Delete the PyTorchJob\n",
+    "\n",
+    "You can delete the created PyTorchJob."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "TrainingClient().delete_job(name=job_name)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.11"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}