Skip to content

Commit

Permalink
Added DistributedSampler to the train_dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
Bobbins228 committed Jun 12, 2024
1 parent 3defc67 commit b29c031
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions demo-notebooks/guided-demos/pytorch_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import tempfile

import torch
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
Expand Down Expand Up @@ -74,10 +74,19 @@ def train_func():
train_data = FashionMNIST(
root=data_dir, train=True, download=True, transform=transform
)
train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)

# Training
model = ImageClassifier()

sampler = DistributedSampler(
train_data,
num_replicas=ray.train.get_context().get_world_size(),
rank=ray.train.get_context().get_world_rank(),
)

train_dataloader = DataLoader(
train_data, batch_size=128, shuffle=False, sampler=sampler
)
# [1] Configure PyTorch Lightning Trainer.
trainer = pl.Trainer(
max_epochs=10,
Expand Down

0 comments on commit b29c031

Please sign in to comment.