From cb1be8ef4cc964a1c5b5de9b3756c48ad93b20b9 Mon Sep 17 00:00:00 2001
From: Antonin Stefanutti
Date: Tue, 14 Jan 2025 14:27:03 +0100
Subject: [PATCH] KEP-2170: Add PyTorch DDP Fashion MNIST training example

Signed-off-by: Antonin Stefanutti
---
 examples/pytorch/mnist-ddp/README.md   | 112 ++++++++
 examples/pytorch/mnist-ddp/mnist.ipynb | 141 ++++++++++
 examples/pytorch/mnist-ddp/mnist.py    | 345 +++++++++++++++++++++++++
 3 files changed, 598 insertions(+)
 create mode 100644 examples/pytorch/mnist-ddp/README.md
 create mode 100644 examples/pytorch/mnist-ddp/mnist.ipynb
 create mode 100644 examples/pytorch/mnist-ddp/mnist.py

diff --git a/examples/pytorch/mnist-ddp/README.md b/examples/pytorch/mnist-ddp/README.md
new file mode 100644
index 0000000000..72b28c986d
--- /dev/null
+++ b/examples/pytorch/mnist-ddp/README.md
@@ -0,0 +1,112 @@
+# PyTorch DDP Fashion MNIST Training Example
+
+This example demonstrates how to train a neural network to classify images
+using the [Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset
+and [PyTorch DDP](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).
+
+You can run this example either with the provided Jupyter notebook,
+or by running the Python script directly.
+
+In either case, you need to install the Kubeflow Training v2 control plane
+on your Kubernetes cluster, if it's not already deployed:
+
+```console
+kubectl apply --server-side -k "https://github.com/kubeflow/training-operator.git/manifests/v2/overlays/standalone?ref=master"
+```
+
+## Jupyter Notebook
+
+You can set up your environment by running the following commands:
+
+```console
+python -m venv .venv
+source .venv/bin/activate
+pip install jupyter
+```
+
+And start the notebook by running:
+
+```console
+jupyter notebook examples/pytorch/mnist-ddp/mnist.ipynb
+```
+
+You can then access the notebook from your Web browser and follow the instructions.
+
+## Python Script
+
+### Setup
+
+You need to set up the Python environment on your local machine or client:
+
+```console
+python -m venv .venv
+source .venv/bin/activate
+pip install git+https://github.com/kubeflow/training-operator.git@master#subdirectory=sdk_v2
+```
+
+You can refer to the [training operator documentation](https://www.kubeflow.org/docs/components/training/installation/)
+for more information.
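+
+Before submitting a job, you can also check that the `torch-distributed` runtime used
+by this example is available on the cluster (assuming the default runtimes shipped with
+the standalone manifests are installed):
+
+```console
+kubectl get clustertrainingruntimes
+```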
+
+### Usage
+
+```console
+python mnist.py --help
+usage: mnist.py [-h] [--batch-size N] [--test-batch-size N] [--epochs N] [--lr LR] [--lr-gamma G] [--lr-period P] [--seed S] [--log-interval N] [--save-model]
+                [--backend {gloo,nccl}] [--num-workers N] [--worker-resources RESOURCE QUANTITY] [--runtime NAME]
+
+PyTorch DDP Fashion MNIST Training Example
+
+options:
+  -h, --help            show this help message and exit
+  --batch-size N        input batch size for training [100]
+  --test-batch-size N   input batch size for testing [100]
+  --epochs N            number of epochs to train [10]
+  --lr LR               learning rate [1e-1]
+  --lr-gamma G          learning rate decay factor [0.5]
+  --lr-period P         learning rate decay period in epochs [20]
+  --seed S              random seed [0]
+  --log-interval N      how many batches to wait before logging training metrics [10]
+  --save-model          save the trained model [False]
+  --backend {gloo,nccl}
+                        Distributed backend [nccl]
+  --num-workers N       Number of workers [1]
+  --worker-resources RESOURCE QUANTITY
+                        Resources per worker [cpu: 1, memory: 2Gi, nvidia.com/gpu: 1]
+  --runtime NAME        the training runtime [torch-distributed]
+```
+
+### Example
+
+Train the model on 4 worker nodes using 1 NVIDIA GPU each:
+
+```console
+python mnist.py \
+    --num-workers 4 \
+    --worker-resources "nvidia.com/gpu" 1 \
+    --worker-resources cpu 4 \
+    --worker-resources memory 16Gi \
+    --epochs 100 \
+    --batch-size 100 \
+    --lr 1e-1 \
+    --lr-period 20 \
+    --lr-gamma 0.8
+```
+
+At the end of each epoch, the local metrics are printed in each worker's logs,
+and the global metrics are gathered and printed in the rank 0 worker's logs.
+
+For example, an epoch evaluation looks like the following in the rank 0 worker logs:
+
+```text
+--------------- Epoch 50 Evaluation ---------------
+
+Local rank 0:
+- Loss: 0.0040
+- Accuracy: 2255/2500 (90%)
+
+Global metrics:
+- Loss: 0.004319
+- Accuracy: 9011/10000 (90.11%)
+
+---------------------------------------------------
+```
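+
+## Cleanup
+
+The SDK creates a `TrainJob` resource on the cluster for each submitted job. Assuming
+the `TrainJob` custom resource installed with the control plane, you can list the jobs
+and delete the ones you don't need anymore by running:
+
+```console
+kubectl get trainjobs
+kubectl delete trainjob <job-name>
+```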
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create the Kubeflow Training Client" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "from kubeflow.training import Trainer, TrainingClient\n", + "from mnist import train_fashion_mnist" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "client = TrainingClient()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Start the Train Job" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "job_name = client.train(\n", + " runtime_ref=\"torch-distributed\",\n", + " trainer=Trainer(\n", + " func=train_fashion_mnist,\n", + " func_args={\n", + " \"backend\": \"nccl\",\n", + " \"batch_size\": 100,\n", + " \"test_batch_size\": 100,\n", + " \"epochs\": 100,\n", + " \"lr\": 1e-1,\n", + " \"lr_gamma\": 0.95,\n", + " \"lr_period\": 20,\n", + " \"seed\": 0,\n", + " \"log_interval\": 10,\n", + " \"save_model\": False,\n", + " },\n", + " num_nodes=4,\n", + " resources_per_node={\n", + " \"nvidia.com/gpu\": 1,\n", + " },\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Watch the Train Job Logs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "client.get_job_logs(job_name, follow=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/pytorch/mnist-ddp/mnist.py b/examples/pytorch/mnist-ddp/mnist.py new file mode 100644 index 0000000000..96e33946b6 --- /dev/null +++ b/examples/pytorch/mnist-ddp/mnist.py @@ -0,0 +1,345 @@ +import argparse + +from kubeflow.training import Trainer, TrainingClient + + +def train_fashion_mnist(dict): + + import os + + import torch + import torch.distributed as dist + import torch.nn as nn + import torch.nn.functional as F + import torchvision.transforms as transforms + from torch.nn.parallel import DistributedDataParallel + from torch.optim.lr_scheduler import StepLR + from torch.utils.data import DataLoader + from torch.utils.data.distributed import DistributedSampler + from torchvision.datasets import FashionMNIST + + backend = dict.get("backend") + batch_size = dict.get("batch_size") + test_batch_size = dict.get("test_batch_size") + epochs = dict.get("epochs") + lr = dict.get("lr") + lr_gamma = dict.get("lr_gamma") + lr_period = dict.get("lr_period") + seed = dict.get("seed") + log_interval = dict.get("log_interval") + save_model = dict.get("save_model") + + class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.flatten = nn.Flatten() + self.linear_relu_stack = nn.Sequential( + nn.Linear(28 * 28, 512), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(512, 512), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(512, 10), + nn.ReLU(), + ) + + def forward(self, x): + x = self.flatten(x) + logits = self.linear_relu_stack(x) + return logits + + def train(model, device, criterion, train_loader, optimizer, epoch, 
+        # Enter training mode
+        model.train()
+        # Iterate over mini-batches from the training set
+        for batch_idx, (inputs, labels) in enumerate(train_loader):
+            # Copy the data to the GPU device if available
+            inputs, labels = inputs.to(device), labels.to(device)
+            # Forward pass
+            outputs = model(inputs)
+            loss = criterion(outputs, labels)
+            # Backward pass
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            if batch_idx % log_interval == 0:
+                print(
+                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+                        epoch,
+                        batch_idx * len(inputs),
+                        len(train_loader.dataset),
+                        100.0 * batch_idx / len(train_loader),
+                        loss.item(),
+                    )
+                )
+
+    def test(model, device, criterion, rank, test_loader, epoch):
+        # Enter evaluation mode
+        model.eval()
+        samples = 0
+        local_loss = 0
+        local_correct = 0
+        # Disable gradient computation to speed up the computation
+        # and reduce memory usage
+        with torch.no_grad():
+            # Iterate over mini-batches from the evaluation set
+            for inputs, labels in test_loader:
+                samples += len(inputs)
+                # Copy the data to the GPU device if available
+                inputs, labels = inputs.to(device), labels.to(device)
+                outputs = model(inputs)
+                # Sum up batch loss
+                local_loss += criterion(outputs, labels).item()
+                # Get the index of the max log-probability
+                pred = outputs.argmax(dim=1, keepdim=True)
+                local_correct += pred.eq(labels.view_as(pred)).sum().item()
+
+        local_accuracy = 100.0 * local_correct / samples
+
+        header = f"{'-'*15} Epoch {epoch} Evaluation {'-'*15}"
+        print(f"\n{header}\n")
+        # Log the local metrics on each rank
+        print(f"Local rank {rank}:")
+        print(f"- Loss: {local_loss / samples:.4f}")
+        print(f"- Accuracy: {local_correct}/{samples} ({local_accuracy:.0f}%)\n")
+
+        # Convert the local metrics to tensors so they can be reduced across ranks
+        global_loss = torch.tensor([local_loss], device=device)
+        global_correct = torch.tensor([local_correct], device=device)
+
+        # Reduce the metrics on rank 0
+        dist.reduce(global_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
+        dist.reduce(global_correct, dst=0, op=torch.distributed.ReduceOp.SUM)
+
+        # Log the aggregated metrics only on rank 0
+        if rank == 0:
+            global_loss = global_loss / len(test_loader.dataset)
+            global_accuracy = (global_correct.double() / len(test_loader.dataset)) * 100
+            global_correct = global_correct.int().item()
+            samples = len(test_loader.dataset)
+            print("Global metrics:")
+            print(f"- Loss: {global_loss.item():.6f}")
+            print(
+                f"- Accuracy: {global_correct}/{samples} ({global_accuracy.item():.2f}%)"
+            )
+        else:
+            print("See rank 0 logs for global metrics")
+        print(f"\n{'-'*len(header)}\n")
+
+    dist.init_process_group(backend=backend)
+
+    torch.manual_seed(seed)
+    rank = int(os.environ["RANK"])
+    local_rank = int(os.environ["LOCAL_RANK"])
+
+    model = Net()
+
+    use_cuda = torch.cuda.is_available()
+    if use_cuda:
+        if backend != torch.distributed.Backend.NCCL:
+            print(
+                "Please use the NCCL distributed backend for the best performance "
+                "when using NVIDIA GPUs"
+            )
+        device = torch.device(f"cuda:{local_rank}")
+        model = DistributedDataParallel(model.to(device), device_ids=[local_rank])
+    else:
+        device = torch.device("cpu")
+        model = DistributedDataParallel(model.to(device))
+
+    transform = transforms.Compose(
+        [
+            transforms.ToTensor(),
+            transforms.Normalize((0.5,), (0.5,)),
+        ]
+    )
+
+    # Create the datasets and data loaders for training and validation,
+    # downloading the data if necessary
+    training_set = FashionMNIST(
+        "./data", train=True, transform=transform, download=True
+    )
+    training_sampler = DistributedSampler(training_set)
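+    # The distributed sampler shards the dataset across the ranks of the process
+    # group, so each worker only loads and trains on its own partition of the data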
+    training_loader = DataLoader(
+        dataset=training_set,
+        batch_size=batch_size,
+        sampler=training_sampler,
+        pin_memory=use_cuda,
+    )
+
+    validation_set = FashionMNIST(
+        "./data", train=False, transform=transform, download=True
+    )
+    validation_sampler = DistributedSampler(validation_set)
+    validation_loader = DataLoader(
+        dataset=validation_set,
+        batch_size=test_batch_size,
+        sampler=validation_sampler,
+        pin_memory=use_cuda,
+    )
+
+    # Report the dataset sizes
+    print("Training set has {} instances".format(len(training_set)))
+    print("Validation set has {} instances\n".format(len(validation_set)))
+
+    criterion = nn.CrossEntropyLoss().to(device)
+    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
+    scheduler = StepLR(optimizer, step_size=lr_period, gamma=lr_gamma)
+
+    for epoch in range(1, epochs + 1):
+        # Set the epoch on the sampler so each epoch uses a different
+        # shuffling of the training data across the ranks
+        training_sampler.set_epoch(epoch)
+        train(
+            model, device, criterion, training_loader, optimizer, epoch, log_interval
+        )
+        test(model, device, criterion, rank, validation_loader, epoch)
+        scheduler.step()
+
+    if save_model and rank == 0:
+        # Save the trained model on rank 0 only to avoid redundant writes
+        torch.save(model.state_dict(), "mnist.pt")
+
+    # Wait so rank 0 can gather the global metrics
+    dist.barrier()
+    dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(
+        description="PyTorch DDP Fashion MNIST Training Example"
+    )
+
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=100,
+        metavar="N",
+        help="input batch size for training [100]",
+    )
+
+    parser.add_argument(
+        "--test-batch-size",
+        type=int,
+        default=100,
+        metavar="N",
+        help="input batch size for testing [100]",
+    )
+
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=10,
+        metavar="N",
+        help="number of epochs to train [10]",
+    )
+
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=1e-1,
+        metavar="LR",
+        help="learning rate [1e-1]",
+    )
+
+    parser.add_argument(
+        "--lr-gamma",
+        type=float,
+        default=0.5,
+        metavar="G",
+        help="learning rate decay factor [0.5]",
+    )
+
+    parser.add_argument(
+        "--lr-period",
+        type=int,
+        default=20,
+        metavar="P",
+        help="learning rate decay period in epochs [20]",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=0,
+        metavar="S",
+        help="random seed [0]",
+    )
+
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=10,
+        metavar="N",
+        help="how many batches to wait before logging training metrics [10]",
+    )
+
+    parser.add_argument(
+        "--save-model",
+        action="store_true",
+        default=False,
+        help="save the trained model [False]",
+    )
+
+    parser.add_argument(
+        "--backend",
+        type=str,
+        choices=["gloo", "nccl"],
+        default="nccl",
+        help="Distributed backend [nccl]",
+    )
+
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=1,
+        metavar="N",
+        help="Number of workers [1]",
+    )
+
+    parser.add_argument(
+        "--worker-resources",
+        type=str,
+        nargs=2,
+        action="append",
+        dest="resources",
+        default=[
+            ("cpu", 1),
+            ("memory", "2Gi"),
+            ("nvidia.com/gpu", 1),
+        ],
+        metavar=("RESOURCE", "QUANTITY"),
+        help="Resources per worker [cpu: 1, memory: 2Gi, nvidia.com/gpu: 1]",
+    )
+
+    parser.add_argument(
+        "--runtime",
+        type=str,
+        default="torch-distributed",
+        metavar="NAME",
+        help="the training runtime [torch-distributed]",
+    )
+
+    args = parser.parse_args()
+
+    client = TrainingClient()
+
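+    # train() packages the local training function and its arguments into a
+    # TrainJob based on the referenced runtime, and returns the job name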
+    job_name = client.train(
+        runtime_ref=args.runtime,
+        trainer=Trainer(
+            func=train_fashion_mnist,
+            func_args={
+                "backend": args.backend,
+                "batch_size": args.batch_size,
+                "test_batch_size": args.test_batch_size,
+                "epochs": args.epochs,
+                "lr": args.lr,
+                "lr_gamma": args.lr_gamma,
+                "lr_period": args.lr_period,
+                "seed": args.seed,
+                "log_interval": args.log_interval,
+                "save_model": args.save_model,
+            },
+            num_nodes=args.num_workers,
+            resources_per_node={
+                resource: quantity for (resource, quantity) in args.resources
+            },
+        ),
+    )
+
+    client.get_job_logs(job_name, follow=True)