Skip to content

Commit

Permalink
Use hydra to change the target YAML file
Browse files Browse the repository at this point in the history
  • Loading branch information
1pha committed Dec 18, 2023
1 parent 508283e commit 673796d
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 16 deletions.
8 changes: 8 additions & 0 deletions config/dataset/biobank_mask.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Hydra target config: UK Biobank masked-region dataset.
_target_: sage.data.mask.UKB_MaskDataset
atlas_name: dkt  # DKT atlas — presumably Desikan-Killiany-Tourville; confirm against sage.data.mask
mask_idx: null  # explicit null instead of a bare value; intended to be set per run via CLI override
root: biobank
label_name: null  # explicit null; assumes the dataset class supplies a default — TODO confirm
mode: train
valid_ratio: 0.1
seed: ${misc.seed}  # interpolated from the parent config's misc.seed
1 change: 1 addition & 0 deletions config/train_cls.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ trainer:
accelerator: gpu
gradient_clip_val: 1
log_every_n_steps: 50
accumulate_grad_batches: 1
# DEBUGGING FLAGS. TODO: Split
# limit_train_batches: 0.001
# limit_val_batches: 0.01
Expand Down
114 changes: 114 additions & 0 deletions config/train_mask.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Training config for the mask-ablation experiment (Hydra + PyTorch Lightning).
hydra:
  job:
    chdir: false  # keep hydra = 1.1 change directory behavior
  run:
    dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}-${logger.name}

defaults:
  - _self_
  - model: resnet
  - dataset: biobank_mask
  - scheduler: cosine_anneal_warmup
  - optim: lion

dataloader:
  _target_: torch.utils.data.DataLoader
  batch_size: 32
  num_workers: 4
  pin_memory: true
  dataset: ${dataset}

misc:
  seed: 42
  debug: false
  modes: [train, valid, test]

module:
  _target_: sage.trainer.PLModule
  _recursive_: false  # let the module instantiate nested configs itself
  augmentation:
    _target_: sage.data.augment
    spatial_size: [160, 192, 160]
  load_from_checkpoint: null  # for lightning
  load_model_ckpt: null  # For model checkpoint only
  separate_lr: null
  save_dir: ${callbacks.checkpoint.dirpath}

metrics:
  mae:
    _target_: torchmetrics.MeanAbsoluteError
  rmse:
    _target_: torchmetrics.MeanSquaredError
    squared: false  # squared=false makes MeanSquaredError report RMSE
  # corr:
  #   _target_: torchmetrics.PearsonCorrCoef
  r2:
    _target_: torchmetrics.R2Score

logger:
  _target_: pytorch_lightning.loggers.WandbLogger
  project: brain-age
  entity: 1pha
  # Quoted: the plain scalar contains " | ", which is legal but easy to break on edit.
  name: "M ${dataset.mask_idx} | ${misc.seed}"
  tags:
    - model=${model.name}
    - REG

# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
trainer:
  _target_: pytorch_lightning.Trainer
  max_epochs: 500
  devices: 1
  accelerator: gpu
  gradient_clip_val: 1
  log_every_n_steps: 50
  accumulate_grad_batches: 1
  # DEBUGGING FLAGS. TODO: Split
  # limit_train_batches: 0.001
  # limit_val_batches: 0.01

callbacks:
  # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html
  checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    dirpath: ${hydra:run.dir}
    filename: "{step}-valid_mae{epoch/valid_MeanAbsoluteError:.3f}"
    monitor: epoch/valid_MeanAbsoluteError
    mode: min
    save_top_k: 1
    save_last: true
    # Is useful to set it to False when metric names contain / as this will result in extra folders
    auto_insert_metric_name: false

  # checkpoint_lr:
  #   _target_: pytorch_lightning.callbacks.ModelCheckpoint
  #   dirpath: ${hydra:run.dir}
  #   filename: "ckpt-step{step}-lr{_lr:.2e}-valid_mae{epoch/valid_MeanAbsoluteError:.3f}"
  #   monitor: _lr
  #   save_top_k: 7
  #   mode: min
  #   auto_insert_metric_name: False

  # https://pytorch-lightning.readthedocs.io/en/stable/common/early_stopping.html
  early_stop:
    _target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping
    monitor: epoch/valid_MeanSquaredError
    mode: min
    min_delta: 0.005
    patience: 20

  # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.LearningRateMonitor.html
  lr_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
    logging_interval: step
    log_momentum: false

  richsummary:
    _target_: pytorch_lightning.callbacks.RichModelSummary

  richpbar:
    _target_: pytorch_lightning.callbacks.RichProgressBar

  manual_ckpt:
    _target_: sage.trainer.callbacks.AnchorCheckpoint
    multiplier: 1
16 changes: 0 additions & 16 deletions train_cls.py

This file was deleted.

0 comments on commit 673796d

Please sign in to comment.