Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an example of training a tabular model on multiple GPUs #474

Merged
merged 15 commits into from
Dec 30, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added an example for training `Trompt` on multiple GPUs ([#474](https://github.com/pyg-team/pytorch-frame/pull/474))
- Added support for materializing dataset for train and test dataframe separately([#470](https://github.com/pyg-team/pytorch-frame/issues/470))
- Added support for PyTorch 2.5 ([#464](https://github.com/pyg-team/pytorch-frame/pull/464))
- Added a benchmark script to compare PyTorch Frame with PyTorch Tabular ([#398](https://github.com/pyg-team/pytorch-frame/pull/398), [#444](https://github.com/pyg-team/pytorch-frame/pull/444))
Expand Down
4 changes: 3 additions & 1 deletion benchmark/data_frame_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def train(
pred, y = model(tf, mixup_encoded=True)
elif isinstance(model, Trompt):
# Trompt uses the layer-wise loss
pred = model.forward_stacked(tf)
pred = model(tf)
num_layers = pred.size(1)
# [batch_size * num_layers, num_classes]
pred = pred.view(-1, out_channels)
Expand Down Expand Up @@ -294,6 +294,8 @@ def test(
for tf in loader:
tf = tf.to(device)
pred = model(tf)
if isinstance(model, Trompt):
pred = pred.mean(dim=1)
if dataset.task_type == TaskType.MULTICLASS_CLASSIFICATION:
pred = pred.argmax(dim=-1)
elif dataset.task_type == TaskType.REGRESSION:
Expand Down
6 changes: 5 additions & 1 deletion benchmark/data_frame_text_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def train(
y = tf.y
if isinstance(model, Trompt):
# Trompt uses the layer-wise loss
pred = model.forward_stacked(tf)
pred = model(tf)
num_layers = pred.size(1)
# [batch_size * num_layers, num_classes]
pred = pred.view(-1, out_channels)
Expand Down Expand Up @@ -337,6 +337,10 @@ def test(
for tf in loader:
tf = tf.to(device)
pred = model(tf)
if isinstance(model, Trompt):
# [batch_size, num_layers, out_channels]
# -> [batch_size, out_channels]
pred = pred.mean(dim=1)
if dataset.task_type == TaskType.MULTICLASS_CLASSIFICATION:
pred = pred.argmax(dim=-1)
elif dataset.task_type == TaskType.REGRESSION:
Expand Down
4 changes: 2 additions & 2 deletions examples/trompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def train(epoch: int) -> float:
for tf in tqdm(train_loader, desc=f"Epoch: {epoch}"):
tf = tf.to(device)
# [batch_size, num_layers, num_classes]
out = model.forward_stacked(tf)
out = model(tf)
num_layers = out.size(1)
# [batch_size * num_layers, num_classes]
pred = out.view(-1, dataset.num_classes)
Expand All @@ -112,7 +112,7 @@ def test(loader: DataLoader) -> float:

for tf in loader:
tf = tf.to(device)
pred = model(tf)
pred = model(tf).mean(dim=1)
pred_class = pred.argmax(dim=-1)
accum += float((tf.y == pred_class).sum())
total_count += len(tf.y)
Expand Down
251 changes: 251 additions & 0 deletions examples/trompt_multi_gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
import argparse
import logging
import os
import os.path as osp

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
import torchmetrics
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from torch_frame.data import DataLoader
from torch_frame.datasets import TabularBenchmark
from torch_frame.nn import Trompt


def prepare_dataset(dataset_str: str) -> TabularBenchmark:
path = osp.join(
osp.dirname(osp.realpath(__file__)),
"..",
"data",
dataset_str,
)
materialized_path = osp.join(path, 'materialized_data.pt')
if dist.get_rank() == 0:
logging.info(f"Preparing dataset '{dataset_str}' from '{path}'")
dataset = TabularBenchmark(root=path, name=dataset_str)
logging.info("Materializing dataset")
dataset.materialize(path=materialized_path)

dist.barrier()
if dist.get_rank() != 0:
logging.info(f"Preparing dataset '{dataset_str}' from '{path}'")
dataset = TabularBenchmark(root=path, name=dataset_str)
logging.info("Loading materialized dataset")
dataset.materialize(path=materialized_path)

dist.barrier()
return dataset


def train(
model: DistributedDataParallel,
epoch: int,
loader: DataLoader,
optimizer: torch.optim.Optimizer,
metric: torchmetrics.Metric,
rank: int,
) -> float:
model.train()
loss_accum = torch.tensor(0.0, device=rank, dtype=torch.float32)
for tf in tqdm(
loader,
desc=f"Epoch {epoch:03d} (train)",
disable=rank != 0,
):
tf = tf.to(rank)
# [batch_size, num_layers, num_classes]
out = model(tf)

with torch.no_grad():
metric.update(out.mean(dim=1).argmax(dim=-1), tf.y)

batch_size, num_layers, num_classes = out.size()
# [batch_size * num_layers, num_classes]
pred = out.view(-1, num_classes)
y = tf.y.repeat_interleave(
num_layers,
output_size=num_layers * batch_size,
)
# Layer-wise logit loss
loss = F.cross_entropy(pred, y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
loss_accum += loss

dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
metric_value = metric.compute()
metric.reset()
return loss_accum, metric_value


@torch.no_grad()
def test(
model: DistributedDataParallel,
epoch: int,
loader: DataLoader,
metric: torchmetrics.Metric,
rank: int,
desc: str,
) -> float:
model.eval()
for tf in tqdm(
loader,
desc=f"Epoch {epoch:03d} ({desc})",
disable=rank != 0,
):
tf = tf.to(rank)
# [batch_size, num_layers, num_classes] -> [batch_size, num_classes]
pred = model(tf).mean(dim=1)
pred_class = pred.argmax(dim=-1)
metric.update(pred_class, tf.y)

metric_value = metric.compute()
metric.reset()
return metric_value


def run(rank: int, world_size: int, args: argparse.Namespace) -> None:
dist.init_process_group(
backend='nccl',
init_method='env://',
world_size=world_size,
rank=rank,
)
logging.basicConfig(
format=(f"[rank={rank}/{world_size}] "
f"[%(asctime)s] %(levelname)s: %(message)s"),
level=logging.INFO,
)
logging.info(f"Initialized rank {rank}/{world_size}")
dataset = prepare_dataset(args.dataset)
assert dataset.task_type.is_classification

# Ensure train, val and test splits are the same across all ranks by
# setting the seed on each rank.
torch.manual_seed(args.seed)
dataset = dataset.shuffle()
train_dataset, val_dataset, test_dataset = (
dataset[:0.7],
dataset[0.7:0.79],
dataset[0.79:],
)
# Note that the last batch of evaluation loops is dropped for now because
# drop_last=False will duplicate samples to fill the last batch, leading to
# the wrong evaluation metrics.
# https://github.com/pytorch/pytorch/issues/25162
train_loader = DataLoader(
train_dataset.tensor_frame,
batch_size=args.batch_size,
sampler=DistributedSampler(
train_dataset,
shuffle=True,
drop_last=True,
),
)
val_loader = DataLoader(
val_dataset.tensor_frame,
batch_size=args.batch_size,
sampler=DistributedSampler(
val_dataset,
shuffle=True,
drop_last=True,
),
)
test_loader = DataLoader(
test_dataset.tensor_frame,
batch_size=args.batch_size,
sampler=DistributedSampler(
test_dataset,
shuffle=True,
drop_last=True,
),
)
model = Trompt(
channels=args.channels,
out_channels=dataset.num_classes,
num_prompts=args.num_prompts,
num_layers=args.num_layers,
col_stats=dataset.col_stats,
col_names_dict=train_dataset.tensor_frame.col_names_dict,
).to(rank)
model = DistributedDataParallel(model, device_ids=[rank])
model = torch.compile(model) if args.compile else model
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
lr_scheduler = ExponentialLR(optimizer, gamma=0.95)
metrics_kwargs = {
"task": "multiclass",
"num_classes": dataset.num_classes,
}
train_metric = torchmetrics.Accuracy(**metrics_kwargs).to(rank)
val_metric = torchmetrics.Accuracy(**metrics_kwargs).to(rank)
test_metric = torchmetrics.Accuracy(**metrics_kwargs).to(rank)
best_val_acc = 0.0
test_acc = 0.0
for epoch in range(1, args.epochs + 1):
train_loader.sampler.set_epoch(epoch)
train_loss, train_acc = train(
model,
epoch,
train_loader,
optimizer,
train_metric,
rank,
)
val_acc = test(
model,
epoch,
val_loader,
val_metric,
rank,
'val',
)
if best_val_acc < val_acc:
best_val_acc = val_acc
test_acc = test(
model,
epoch,
test_loader,
test_metric,
rank,
'test',
)
if rank == 0:
print(f"Train Loss: {train_loss:.4f}, "
f"Train Acc: {train_acc:.4f}, "
f"Val Acc: {val_acc:.4f}")

lr_scheduler.step()

if rank == 0:
print(f"Best Val Acc: {best_val_acc:.4f}, "
f"Test Acc: {test_acc:.4f}")

dist.destroy_process_group()
logging.info("Process group destroyed")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="california")
parser.add_argument("--channels", type=int, default=128)
parser.add_argument("--num_prompts", type=int, default=128)
parser.add_argument("--num_layers", type=int, default=6)
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--compile", action="store_true")
args = parser.parse_args()

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

world_size = torch.cuda.device_count()
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
2 changes: 1 addition & 1 deletion test/nn/models/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
Trompt,
dict(channels=8, num_prompts=2),
None,
4,
3,
id="Trompt",
),
pytest.param(
Expand Down
4 changes: 1 addition & 3 deletions test/nn/models/test_trompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,5 @@ def test_trompt(batch_size, use_stype_encoder_dicts):
stype_encoder_dicts=stype_encoder_dicts,
)
model.reset_parameters()
out = model.forward_stacked(tensor_frame)
assert out.shape == (batch_size, num_layers, out_channels)
pred = model(tensor_frame)
assert pred.shape == (batch_size, out_channels)
assert pred.shape == (batch_size, num_layers, out_channels)
5 changes: 1 addition & 4 deletions torch_frame/nn/models/trompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def reset_parameters(self) -> None:
trompt_conv.reset_parameters()
self.trompt_decoder.reset_parameters()

def forward_stacked(self, tf: TensorFrame) -> Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make sure this change does not break the example code.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirmed the change doesn't break these scripts across all supported task types:

examples/trompt.py
benchmark/data_frame_benchmark.py
benchmark/data_frame_text_benchmark.py

def forward(self, tf: TensorFrame) -> Tensor:
r"""Transforming :class:`TensorFrame` object into a series of output
predictions at each layer. Used during training to compute layer-wise
loss.
Expand Down Expand Up @@ -152,6 +152,3 @@ def forward_stacked(self, tf: TensorFrame) -> Tensor:
# [batch_size, num_layers, out_channels]
stacked_out = torch.cat(outs, dim=1)
return stacked_out

def forward(self, tf: TensorFrame) -> Tensor:
return self.forward_stacked(tf).mean(dim=1)
Loading