Subclass hydra configurations
1pha committed Mar 11, 2024
1 parent a81b206 commit ae254fb
Showing 23 changed files with 110 additions and 369 deletions.
1 change: 1 addition & 0 deletions config/callbacks/checkpoint/README.md
@@ -0,0 +1 @@
https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html
File renamed without changes.
File renamed without changes.
File renamed without changes.
1 change: 1 addition & 0 deletions config/callbacks/early_stop/README.md
@@ -0,0 +1 @@
https://pytorch-lightning.readthedocs.io/en/stable/common/early_stopping.html
File renamed without changes.
File renamed without changes.
File renamed without changes.
4 changes: 4 additions & 0 deletions config/callbacks/lr_monitor/default.yaml
@@ -0,0 +1,4 @@
# https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.LearningRateMonitor.html
_target_: pytorch_lightning.callbacks.LearningRateMonitor
logging_interval: step
log_momentum: False
1 change: 1 addition & 0 deletions config/callbacks/richsummary/default.yaml
@@ -0,0 +1 @@
_target_: pytorch_lightning.callbacks.RichModelSummary
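
Each callbacks/* group above holds a single instantiable node, so trainer-side code can build the callback list in one pass. A minimal sketch, assuming pytorch_lightning, hydra, and rich are installed; the keys and Trainer arguments below are illustrative, not taken from this commit:

# Sketch: build callbacks from config nodes like the two groups above and hand
# them to a Trainer. Not the repository's actual entry point.
from hydra.utils import instantiate
from omegaconf import OmegaConf
import pytorch_lightning as pl

cfg = OmegaConf.create({
    "lr_monitor": {"_target_": "pytorch_lightning.callbacks.LearningRateMonitor",
                   "logging_interval": "step", "log_momentum": False},
    "richsummary": {"_target_": "pytorch_lightning.callbacks.RichModelSummary"},
})
callbacks = [instantiate(node) for node in cfg.values()]
trainer = pl.Trainer(max_epochs=1, accelerator="cpu", callbacks=callbacks)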
5 changes: 5 additions & 0 deletions config/dataloader/default.yaml
@@ -0,0 +1,5 @@
_target_: torch.utils.data.DataLoader
batch_size: 4
num_workers: 2
pin_memory: True
dataset: ${dataset}
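
The dataloader group is a plain instantiable node whose dataset entry is an interpolation into the composed config. A minimal sketch of how such a node can be instantiated, with a TensorDataset standing in for the repository's dataset (the stand-in and the tensor shapes are assumptions, not part of this diff):

# Sketch: instantiate a DataLoader node like config/dataloader/default.yaml.
import torch
from torch.utils.data import TensorDataset
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "_target_": "torch.utils.data.DataLoader",
    "batch_size": 4,
    "num_workers": 2,
    "pin_memory": True,
})
dummy = TensorDataset(torch.randn(8, 3), torch.randn(8, 1))  # illustrative stand-in dataset
loader = instantiate(cfg, dataset=dummy)  # kwargs passed here extend the config node
print(len(loader))  # 8 samples / batch_size 4 -> 2 batches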
4 changes: 4 additions & 0 deletions config/logger/wandb.yaml
@@ -0,0 +1,4 @@
_target_: pytorch_lightning.loggers.WandbLogger
project: brain-age
entity: 1pha
name: ${model.name}
7 changes: 7 additions & 0 deletions config/logger/wandb_mask.yaml
@@ -0,0 +1,7 @@
_target_: pytorch_lightning.loggers.WandbLogger
project: brain-age
entity: 1pha
name: M ${dataset.mask_idx} | ${misc.seed}
tags:
- model=${model.name}
- MASK
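
The ${...} entries in both logger configs are OmegaConf interpolations that resolve against the composed top-level config at access time. A small sketch with placeholder model/misc/dataset values; only the key paths mirror the configs above, the values are assumptions:

# Sketch: how the logger name/tags interpolations resolve.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "model": {"name": "resnet10"},   # placeholder values
    "misc": {"seed": 42},
    "dataset": {"mask_idx": 3},
    "logger": {
        "name": "M ${dataset.mask_idx} | ${misc.seed}",
        "tags": ["model=${model.name}", "MASK"],
    },
})
print(cfg.logger.name)     # -> "M 3 | 42"
print(cfg.logger.tags[0])  # -> "model=resnet10"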
11 changes: 11 additions & 0 deletions config/metrics/binary.yaml
@@ -0,0 +1,11 @@
acc:
_target_: torchmetrics.Accuracy
task: binary
average: weighted
f1:
_target_: torchmetrics.F1Score
task: binary
average: macro
auroc:
_target_: torchmetrics.AUROC
task: binary
14 changes: 14 additions & 0 deletions config/metrics/cls.yaml
@@ -0,0 +1,14 @@
acc:
_target_: torchmetrics.Accuracy
task: multiclass
num_classes: ${model.backbone.num_classes}
average: weighted
f1:
_target_: torchmetrics.F1Score
task: multiclass
num_classes: ${model.backbone.num_classes}
average: macro
auroc:
_target_: torchmetrics.AUROC
task: multiclass
num_classes: ${model.backbone.num_classes}
9 changes: 9 additions & 0 deletions config/metrics/reg.yaml
@@ -0,0 +1,9 @@
mae:
_target_: torchmetrics.MeanAbsoluteError
rmse:
_target_: torchmetrics.MeanSquaredError
squared: False
corr:
_target_: torchmetrics.PearsonCorrCoef
r2:
_target_: torchmetrics.R2Score
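
Each metrics/*.yaml file is a flat mapping of instantiable nodes, so a dict of torchmetrics objects can be built in one comprehension. A minimal sketch using the regression group; the sample predictions and targets are illustrative:

# Sketch: turn a metrics config group into torchmetrics objects and evaluate them.
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "mae": {"_target_": "torchmetrics.MeanAbsoluteError"},
    "rmse": {"_target_": "torchmetrics.MeanSquaredError", "squared": False},
    "r2": {"_target_": "torchmetrics.R2Score"},
})
metrics = {name: instantiate(node) for name, node in cfg.items()}
preds = torch.tensor([2.5, 0.0, 2.0])
target = torch.tensor([3.0, -0.5, 2.0])
print({name: metric(preds, target).item() for name, metric in metrics.items()})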
3 changes: 3 additions & 0 deletions config/misc/default.yaml
@@ -0,0 +1,3 @@
seed: 42
debug: False
modes: [ train, valid ]
9 changes: 9 additions & 0 deletions config/module/default.yaml
@@ -0,0 +1,9 @@
_target_: sage.trainer.PLModule
_recursive_: False
augmentation:
spatial_size: [ 160 , 192 , 160 ]
load_from_checkpoint: # for lightning
load_model_ckpt: # For model checkpoint only
log_train_metrics: True
separate_lr:
save_dir: ${callbacks.checkpoint.dirpath}
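
The _recursive_: False flag changes what instantiate hands to the target: nested nodes arrive as DictConfig objects instead of being instantiated eagerly, which lets the module decide when to build them. A toy sketch; Wrapper stands in for sage.trainer.PLModule, whose real signature is not shown in this commit:

# Sketch: with _recursive_: False, nested nodes stay as DictConfig.
from hydra.utils import instantiate
from omegaconf import OmegaConf

class Wrapper:  # toy stand-in for sage.trainer.PLModule
    def __init__(self, inner):
        self.inner = inner

cfg = OmegaConf.create({
    "_target_": "__main__.Wrapper",
    "_recursive_": False,
    "inner": {"_target_": "collections.Counter"},
})
obj = instantiate(cfg)
print(type(obj.inner))  # DictConfig, not a Counter: the target instantiates it later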
106 changes: 10 additions & 96 deletions config/train.yaml
@@ -10,104 +10,18 @@ defaults:
- dataset: biobank
- scheduler: cosine_anneal_warmup
- optim: lion

dataloader:
_target_: torch.utils.data.DataLoader
batch_size: 32
num_workers: 8
pin_memory: True
dataset: ${dataset}

misc:
seed: 42
debug: False
modes: [ train, valid, test ]

module:
_target_: sage.trainer.PLModule
_recursive_: False
augmentation:
spatial_size: [ 160 , 192 , 160 ]
load_from_checkpoint: # for lightning
load_model_ckpt: # For model checkpoint only
separate_lr:
save_dir: ${callbacks.checkpoint.dirpath}

metrics:
mae:
_target_: torchmetrics.MeanAbsoluteError
rmse:
_target_: torchmetrics.MeanSquaredError
squared: False
# corr:
# _target_: torchmetrics.PearsonCorrCoef
r2:
_target_: torchmetrics.R2Score

logger:
_target_: pytorch_lightning.loggers.WandbLogger
project: brain-age
entity: 1pha
name: R ${model.name} | ${misc.seed}
tags:
- model=${model.name}
- REG

# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
trainer:
_target_: pytorch_lightning.Trainer
max_epochs: 500
devices: 1
accelerator: gpu
gradient_clip_val: 1
log_every_n_steps: 50
accumulate_grad_batches: 1
# DEBUGGING FLAGS. TODO: Split
# limit_train_batches: 0.001
# limit_val_batches: 0.01
- trainer: default
- dataloader: default
- metrics: reg
- logger: wandb
- module: default
- callbacks/checkpoint: reg
- callbacks/early_stop: reg
- callbacks/lr_monitor: default
- callbacks/richsummary: default
- misc: default

callbacks:
# https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html
checkpoint:
_target_: pytorch_lightning.callbacks.ModelCheckpoint
dirpath: ${hydra:run.dir}
filename: "{step}-valid_mae{epoch/valid_MeanAbsoluteError:.3f}"
monitor: epoch/valid_MeanAbsoluteError
mode: min
save_top_k: 1
save_last: True
# Is useful to set it to False when metric names contain / as this will result in extra folders
auto_insert_metric_name: False

# checkpoint_lr:
# _target_: pytorch_lightning.callbacks.ModelCheckpoint
# dirpath: ${hydra:run.dir}
# filename: "ckpt-step{step}-lr{_lr:.2e}-valid_mae{epoch/valid_MeanAbsoluteError:.3f}"
# monitor: _lr
# save_top_k: 7
# mode: min
# auto_insert_metric_name: False

# https://pytorch-lightning.readthedocs.io/en/stable/common/early_stopping.html
early_stop:
_target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping
monitor: epoch/valid_MeanSquaredError
mode: min
min_delta: 0.005
patience: 20

# https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.LearningRateMonitor.html
lr_monitor:
_target_: pytorch_lightning.callbacks.LearningRateMonitor
logging_interval: step
log_momentum: False

richsummary:
_target_: pytorch_lightning.callbacks.RichModelSummary

richpbar:
_target_: pytorch_lightning.callbacks.RichProgressBar

manual_ckpt:
_target_: sage.trainer.callbacks.AnchorCheckpoint
multiplier: 1
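
After this change train.yaml is essentially a defaults list, and each entry composes under its group name. A sketch of an entry point that would consume it, assuming it is run from the repository root; the file name train.py is an assumption, the actual entry point is not part of this diff:

# Sketch: compose the refactored train.yaml and inspect the merged config.
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path="config", config_name="train", version_base=None)
def main(cfg: DictConfig) -> None:
    # Each defaults entry (trainer, dataloader, metrics, logger, module,
    # callbacks/*, misc) lands under its group name in the composed config.
    print(OmegaConf.to_yaml(cfg, resolve=False))

if __name__ == "__main__":
    main()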
97 changes: 10 additions & 87 deletions config/train_binary.yaml
@@ -12,90 +12,13 @@ defaults:
- dataset: ppmi
- scheduler: exp_decay
- optim: adamw

dataloader:
_target_: torch.utils.data.DataLoader
batch_size: 4
num_workers: 2
pin_memory: True
dataset: ${dataset}

misc:
seed: 42
debug: False
modes: [ train, valid ]

module:
_target_: sage.trainer.PLModule
_recursive_: False
augmentation:
spatial_size: [ 160 , 192 , 160 ]
load_from_checkpoint: # for lightning
load_model_ckpt: # For model checkpoint only
log_train_metrics: True
separate_lr:
save_dir: ${callbacks.checkpoint.dirpath}

metrics:
acc:
_target_: torchmetrics.Accuracy
task: binary
average: weighted
f1:
_target_: torchmetrics.F1Score
task: binary
average: macro
auroc:
_target_: torchmetrics.AUROC
task: binary

logger:
_target_: pytorch_lightning.loggers.WandbLogger
project: brain-age
entity: 1pha
name: C ${model.name}

# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
trainer:
_target_: pytorch_lightning.Trainer
max_epochs: 300
devices: 1
accelerator: gpu
gradient_clip_val: 1
log_every_n_steps: 50
accumulate_grad_batches: 4
# DEBUGGING FLAGS. TODO: Split
# limit_train_batches: 0.001
# limit_val_batches: 0.01

callbacks:
# https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html
checkpoint:
_target_: pytorch_lightning.callbacks.ModelCheckpoint
dirpath: ${hydra.run.dir}
filename: "{step}-valid_f1-{epoch/valid_BinaryF1Score:.3f}"
monitor: epoch/valid_BinaryF1Score
mode: max
save_top_k: 1
save_last: True
# Is useful to set it to False when metric names contain / as this will result in extra folders
auto_insert_metric_name: False

# https://pytorch-lightning.readthedocs.io/en/stable/common/early_stopping.html
early_stop:
_target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping
monitor: epoch/valid_BinaryF1Score
mode: max
patience: 10

# https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.LearningRateMonitor.html
lr_monitor:
_target_: pytorch_lightning.callbacks.LearningRateMonitor
logging_interval: step
log_momentum: False

richsummary:
_target_: pytorch_lightning.callbacks.RichModelSummary

richpbar:
_target_: pytorch_lightning.callbacks.RichProgressBar
- trainer: default
- dataloader: default
- metrics: binary
- logger: wandb
- module: default
- callbacks/checkpoint: binary
- callbacks/early_stop: binary
- callbacks/lr_monitor: default
- callbacks/richsummary: default
- misc: default
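
The same factoring applies to train_binary.yaml, so switching between the regression and binary setups reduces to picking different group options. A sketch using Hydra's compose API to check that the binary config composes and that scalar overrides take effect, assuming it runs from the repository root with the full config tree present; the override values are illustrative:

# Sketch: compose train_binary.yaml with a couple of overrides.
from hydra import compose, initialize

with initialize(config_path="config", version_base=None):
    cfg = compose(
        config_name="train_binary",
        overrides=["dataloader.batch_size=8", "misc.seed=7"],
    )
    print(list(cfg.metrics.keys()))   # ['acc', 'f1', 'auroc'] from metrics/binary.yaml
    print(cfg.dataloader.batch_size)  # 8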