From b29c031b9fea3d9fe8559b4d6084d4a327db451a Mon Sep 17 00:00:00 2001
From: Bobbins228
Date: Wed, 12 Jun 2024 11:02:17 +0100
Subject: [PATCH] Added DistributedSampler to the train_dataloader

---
 demo-notebooks/guided-demos/pytorch_lightning.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/demo-notebooks/guided-demos/pytorch_lightning.py b/demo-notebooks/guided-demos/pytorch_lightning.py
index 0769f68b3..a1dd13caa 100644
--- a/demo-notebooks/guided-demos/pytorch_lightning.py
+++ b/demo-notebooks/guided-demos/pytorch_lightning.py
@@ -2,7 +2,7 @@
 import tempfile
 
 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, DistributedSampler
 from torchvision.models import resnet18
 from torchvision.datasets import FashionMNIST
 from torchvision.transforms import ToTensor, Normalize, Compose
@@ -74,10 +74,19 @@ def train_func():
     train_data = FashionMNIST(
         root=data_dir, train=True, download=True, transform=transform
     )
-    train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)
 
     # Training
     model = ImageClassifier()
+
+    sampler = DistributedSampler(
+        train_data,
+        num_replicas=ray.train.get_context().get_world_size(),
+        rank=ray.train.get_context().get_world_rank(),
+    )
+
+    train_dataloader = DataLoader(
+        train_data, batch_size=128, shuffle=False, sampler=sampler
+    )
 
     # [1] Configure PyTorch Lightning Trainer.
     trainer = pl.Trainer(
         max_epochs=10,
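
Note (not part of the commit): the patch correctly sets shuffle=False on the
DataLoader, since shuffle and an explicit sampler are mutually exclusive;
DistributedSampler shuffles internally by default. One related detail is that
DistributedSampler only reshuffles across epochs when set_epoch() is called.
PyTorch Lightning's Trainer should take care of that for samplers it detects
on the dataloader, so the patched demo ought not to need it, but in a
hand-written PyTorch loop it must be done manually. A minimal sketch of that
pattern follows; train_loop and num_epochs are illustrative names, not part
of the patched demo:

import ray.train
from torch.utils.data import DataLoader, DistributedSampler

def train_loop(train_data, model, num_epochs=10):
    # Shard the dataset across Ray Train workers, mirroring the patch.
    ctx = ray.train.get_context()
    sampler = DistributedSampler(
        train_data,
        num_replicas=ctx.get_world_size(),
        rank=ctx.get_world_rank(),
    )
    # shuffle must be False whenever a sampler is supplied; shuffling
    # is delegated to the sampler itself.
    train_dataloader = DataLoader(
        train_data, batch_size=128, shuffle=False, sampler=sampler
    )
    for epoch in range(num_epochs):
        # Reseed the sampler so each epoch yields a different shuffle
        # order; without this call every epoch iterates identically.
        sampler.set_epoch(epoch)
        for batch in train_dataloader:
            ...  # forward/backward/optimizer step for `model`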