From ae254fb06338c13beed15602702d1860ca3e85b9 Mon Sep 17 00:00:00 2001
From: 1pha <1phantasmas@korea.ac.kr>
Date: Mon, 11 Mar 2024 08:17:36 +0000
Subject: [PATCH] Subclass hydra configurations

---
 config/callbacks/checkpoint/README.md     |   1 +
 .../binary.yaml}                          |   0
 .../cls.yaml}                             |   0
 .../reg.yaml}                             |   0
 config/callbacks/early_stop/README.md     |   1 +
 .../binary.yaml}                          |   0
 .../cls.yaml}                             |   0
 .../reg.yaml}                             |   0
 config/callbacks/lr_monitor/default.yaml  |   4 +
 config/callbacks/richsummary/default.yaml |   1 +
 config/dataloader/default.yaml            |   5 +
 config/logger/wandb.yaml                  |   4 +
 config/logger/wandb_mask.yaml             |   7 ++
 config/metrics/binary.yaml                |  11 ++
 config/metrics/cls.yaml                   |  14 +++
 config/metrics/reg.yaml                   |   9 ++
 config/misc/default.yaml                  |   3 +
 config/module/default.yaml                |   9 ++
 config/train.yaml                         | 106 ++----------------
 config/train_binary.yaml                  |  97 ++--------------
 config/train_cls.yaml                     | 100 ++---------------
 config/train_mask.yaml                    | 106 ++----------------
 config/trainer/default.yaml               |   1 +
 23 files changed, 110 insertions(+), 369 deletions(-)
 create mode 100644 config/callbacks/checkpoint/README.md
 rename config/callbacks/{checkpoint_binary.yaml => checkpoint/binary.yaml} (100%)
 rename config/callbacks/{checkpoint_cls.yaml => checkpoint/cls.yaml} (100%)
 rename config/callbacks/{checkpoint_reg.yaml => checkpoint/reg.yaml} (100%)
 create mode 100644 config/callbacks/early_stop/README.md
 rename config/callbacks/{early_stop_binary.yaml => early_stop/binary.yaml} (100%)
 rename config/callbacks/{early_stop_cls.yaml => early_stop/cls.yaml} (100%)
 rename config/callbacks/{early_stop_reg.yaml => early_stop/reg.yaml} (100%)
 create mode 100644 config/callbacks/lr_monitor/default.yaml
 create mode 100644 config/callbacks/richsummary/default.yaml
 create mode 100644 config/dataloader/default.yaml
 create mode 100644 config/logger/wandb.yaml
 create mode 100644 config/logger/wandb_mask.yaml
 create mode 100644 config/metrics/binary.yaml
 create mode 100644 config/metrics/cls.yaml
 create mode 100644 config/metrics/reg.yaml
 create mode 100644 config/misc/default.yaml
 create mode 100644 config/module/default.yaml

diff --git a/config/callbacks/checkpoint/README.md b/config/callbacks/checkpoint/README.md
new file mode 100644
index 0000000..67084df
--- /dev/null
+++ b/config/callbacks/checkpoint/README.md
@@ -0,0 +1 @@
+ https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html
diff --git a/config/callbacks/checkpoint_binary.yaml b/config/callbacks/checkpoint/binary.yaml
similarity index 100%
rename from config/callbacks/checkpoint_binary.yaml
rename to config/callbacks/checkpoint/binary.yaml
diff --git a/config/callbacks/checkpoint_cls.yaml b/config/callbacks/checkpoint/cls.yaml
similarity index 100%
rename from config/callbacks/checkpoint_cls.yaml
rename to config/callbacks/checkpoint/cls.yaml
diff --git a/config/callbacks/checkpoint_reg.yaml b/config/callbacks/checkpoint/reg.yaml
similarity index 100%
rename from config/callbacks/checkpoint_reg.yaml
rename to config/callbacks/checkpoint/reg.yaml
diff --git a/config/callbacks/early_stop/README.md b/config/callbacks/early_stop/README.md
new file mode 100644
index 0000000..45bf26c
--- /dev/null
+++ b/config/callbacks/early_stop/README.md
@@ -0,0 +1 @@
+https://pytorch-lightning.readthedocs.io/en/stable/common/early_stopping.html
\ No newline at end of file
diff --git a/config/callbacks/early_stop_binary.yaml b/config/callbacks/early_stop/binary.yaml
similarity index 100%
rename from config/callbacks/early_stop_binary.yaml
rename to config/callbacks/early_stop/binary.yaml
diff --git a/config/callbacks/early_stop_cls.yaml b/config/callbacks/early_stop/cls.yaml
similarity index 100%
rename from config/callbacks/early_stop_cls.yaml
rename to config/callbacks/early_stop/cls.yaml
diff --git a/config/callbacks/early_stop_reg.yaml b/config/callbacks/early_stop/reg.yaml
similarity index 100%
rename from config/callbacks/early_stop_reg.yaml
rename to config/callbacks/early_stop/reg.yaml
diff --git a/config/callbacks/lr_monitor/default.yaml b/config/callbacks/lr_monitor/default.yaml
new file mode 100644
index 0000000..542e572
--- /dev/null
+++ b/config/callbacks/lr_monitor/default.yaml
@@ -0,0 +1,4 @@
+# https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.LearningRateMonitor.html
+_target_: pytorch_lightning.callbacks.LearningRateMonitor
+logging_interval: step
+log_momentum: False
\ No newline at end of file
diff --git a/config/callbacks/richsummary/default.yaml b/config/callbacks/richsummary/default.yaml
new file mode 100644
index 0000000..87dbec4
--- /dev/null
+++ b/config/callbacks/richsummary/default.yaml
@@ -0,0 +1 @@
+_target_: pytorch_lightning.callbacks.RichModelSummary
\ No newline at end of file
diff --git a/config/dataloader/default.yaml b/config/dataloader/default.yaml
new file mode 100644
index 0000000..3db269d
--- /dev/null
+++ b/config/dataloader/default.yaml
@@ -0,0 +1,5 @@
+_target_: torch.utils.data.DataLoader
+batch_size: 4
+num_workers: 2
+pin_memory: True
+dataset: ${dataset}
\ No newline at end of file
diff --git a/config/logger/wandb.yaml b/config/logger/wandb.yaml
new file mode 100644
index 0000000..227076b
--- /dev/null
+++ b/config/logger/wandb.yaml
@@ -0,0 +1,4 @@
+_target_: pytorch_lightning.loggers.WandbLogger
+project: brain-age
+entity: 1pha
+name: ${model.name}
\ No newline at end of file
diff --git a/config/logger/wandb_mask.yaml b/config/logger/wandb_mask.yaml
new file mode 100644
index 0000000..7f6b879
--- /dev/null
+++ b/config/logger/wandb_mask.yaml
@@ -0,0 +1,7 @@
+_target_: pytorch_lightning.loggers.WandbLogger
+project: brain-age
+entity: 1pha
+name: M ${dataset.mask_idx} | ${misc.seed}
+tags:
+  - model=${model.name}
+  - MASK
\ No newline at end of file
diff --git a/config/metrics/binary.yaml b/config/metrics/binary.yaml
new file mode 100644
index 0000000..b2bc57f
--- /dev/null
+++ b/config/metrics/binary.yaml
@@ -0,0 +1,11 @@
+acc:
+  _target_: torchmetrics.Accuracy
+  task: binary
+  average: weighted
+f1:
+  _target_: torchmetrics.F1Score
+  task: binary
+  average: macro
+auroc:
+  _target_: torchmetrics.AUROC
+  task: binary
\ No newline at end of file
diff --git a/config/metrics/cls.yaml b/config/metrics/cls.yaml
new file mode 100644
index 0000000..d58c903
--- /dev/null
+++ b/config/metrics/cls.yaml
@@ -0,0 +1,14 @@
+acc:
+  _target_: torchmetrics.Accuracy
+  task: multiclass
+  num_classes: ${model.backbone.num_classes}
+  average: weighted
+f1:
+  _target_: torchmetrics.F1Score
+  task: multiclass
+  num_classes: ${model.backbone.num_classes}
+  average: macro
+auroc:
+  _target_: torchmetrics.AUROC
+  task: multiclass
+  num_classes: ${model.backbone.num_classes}
\ No newline at end of file
diff --git a/config/metrics/reg.yaml b/config/metrics/reg.yaml
new file mode 100644
index 0000000..0f0c973
--- /dev/null
+++ b/config/metrics/reg.yaml
@@ -0,0 +1,9 @@
+mae:
+  _target_: torchmetrics.MeanAbsoluteError
+rmse:
+  _target_: torchmetrics.MeanSquaredError
+  squared: False
+corr:
+  _target_: torchmetrics.PearsonCorrCoef
+r2:
+  _target_: torchmetrics.R2Score
\ No newline at end of file
diff --git a/config/misc/default.yaml b/config/misc/default.yaml
new file mode 100644
index 0000000..b41c965
--- /dev/null
+++ b/config/misc/default.yaml
@@ -0,0 +1,3 @@
+seed: 42
+debug: False
+modes: [ train, valid ]
diff --git a/config/module/default.yaml b/config/module/default.yaml
new file mode 100644
index 0000000..2cdd451
--- /dev/null
+++ b/config/module/default.yaml
@@ -0,0 +1,9 @@
+_target_: sage.trainer.PLModule
+_recursive_: False
+augmentation:
+  spatial_size: [ 160 , 192 , 160 ]
+load_from_checkpoint: # for lightning
+load_model_ckpt: # For model checkpoint only
+log_train_metrics: True
+separate_lr:
+save_dir: ${callbacks.checkpoint.dirpath}
diff --git a/config/train.yaml b/config/train.yaml
index 03b7586..a208c1f 100644
--- a/config/train.yaml
+++ b/config/train.yaml
@@ -10,104 +10,18 @@ defaults:
   - dataset: biobank
   - scheduler: cosine_anneal_warmup
   - optim: lion
-
-dataloader:
-  _target_: torch.utils.data.DataLoader
-  batch_size: 32
-  num_workers: 8
-  pin_memory: True
-  dataset: ${dataset}
-
-misc:
-  seed: 42
-  debug: False
-  modes: [ train, valid, test ]
-
-module:
-  _target_: sage.trainer.PLModule
-  _recursive_: False
-  augmentation:
-    spatial_size: [ 160 , 192 , 160 ]
-  load_from_checkpoint: # for lightning
-  load_model_ckpt: # For model checkpoint only
-  separate_lr:
-  save_dir: ${callbacks.checkpoint.dirpath}
-
-metrics:
-  mae:
-    _target_: torchmetrics.MeanAbsoluteError
-  rmse:
-    _target_: torchmetrics.MeanSquaredError
-    squared: False
-  # corr:
-  #   _target_: torchmetrics.PearsonCorrCoef
-  r2:
-    _target_: torchmetrics.R2Score
-
-logger:
-  _target_: pytorch_lightning.loggers.WandbLogger
-  project: brain-age
-  entity: 1pha
-  name: R ${model.name} | ${misc.seed}
-  tags:
-    - model=${model.name}
-    - REG
-
-# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
-trainer:
-  _target_: pytorch_lightning.Trainer
-  max_epochs: 500
-  devices: 1
-  accelerator: gpu
-  gradient_clip_val: 1
-  log_every_n_steps: 50
-  accumulate_grad_batches: 1
-  # DEBUGGING FLAGS. TODO: Split
-  # limit_train_batches: 0.001
-  # limit_val_batches: 0.01
+  - trainer: default
+  - dataloader: default
+  - metrics: reg
+  - logger: wandb
+  - module: default
+  - callbacks/checkpoint: reg
+  - callbacks/early_stop: reg
+  - callbacks/lr_monitor: default
+  - callbacks/richsummary: default
+  - misc: default
 
 callbacks:
-  # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html
-  checkpoint:
-    _target_: pytorch_lightning.callbacks.ModelCheckpoint
-    dirpath: ${hydra:run.dir}
-    filename: "{step}-valid_mae{epoch/valid_MeanAbsoluteError:.3f}"
-    monitor: epoch/valid_MeanAbsoluteError
-    mode: min
-    save_top_k: 1
-    save_last: True
-    # Is useful to set it to False when metric names contain / as this will result in extra folders
-    auto_insert_metric_name: False
-
-  # checkpoint_lr:
-  #   _target_: pytorch_lightning.callbacks.ModelCheckpoint
-  #   dirpath: ${hydra:run.dir}
-  #   filename: "ckpt-step{step}-lr{_lr:.2e}-valid_mae{epoch/valid_MeanAbsoluteError:.3f}"
-  #   monitor: _lr
-  #   save_top_k: 7
-  #   mode: min
-  #   auto_insert_metric_name: False
-
-  # https://pytorch-lightning.readthedocs.io/en/stable/common/early_stopping.html
-  early_stop:
-    _target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping
-    monitor: epoch/valid_MeanSquaredError
-    mode: min
-    min_delta: 0.005
-    patience: 20
-
-  # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.LearningRateMonitor.html
-  lr_monitor:
-    _target_: pytorch_lightning.callbacks.LearningRateMonitor
-    logging_interval: step
-    log_momentum: False
-
-  richsummary:
-    _target_: pytorch_lightning.callbacks.RichModelSummary
-
-  richpbar:
-    _target_: pytorch_lightning.callbacks.RichProgressBar
-
   manual_ckpt:
     _target_: sage.trainer.callbacks.AnchorCheckpoint
     multiplier: 1
\ No newline at end of file
diff --git a/config/train_binary.yaml b/config/train_binary.yaml
index 28a87c3..0bd5507 100644
--- a/config/train_binary.yaml
+++ b/config/train_binary.yaml
@@ -12,90 +12,13 @@ defaults:
   - dataset: ppmi
   - scheduler: exp_decay
   - optim: adamw
-
-dataloader:
-  _target_: torch.utils.data.DataLoader
-  batch_size: 4
-  num_workers: 2
-  pin_memory: True
-  dataset: ${dataset}
-
-misc:
-  seed: 42
-  debug: False
-  modes: [ train, valid ]
-
-module:
-  _target_: sage.trainer.PLModule
-  _recursive_: False
-  augmentation:
-    spatial_size: [ 160 , 192 , 160 ]
-  load_from_checkpoint: # for lightning
-  load_model_ckpt: # For model checkpoint only
-  log_train_metrics: True
-  separate_lr:
-  save_dir: ${callbacks.checkpoint.dirpath}
-
-metrics:
-  acc:
-    _target_: torchmetrics.Accuracy
-    task: binary
-    average: weighted
-  f1:
-    _target_: torchmetrics.F1Score
-    task: binary
-    average: macro
-  auroc:
-    _target_: torchmetrics.AUROC
-    task: binary
-
-logger:
-  _target_: pytorch_lightning.loggers.WandbLogger
-  project: brain-age
-  entity: 1pha
-  name: C ${model.name}
-
-# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
-trainer:
-  _target_: pytorch_lightning.Trainer
-  max_epochs: 300
-  devices: 1
-  accelerator: gpu
-  gradient_clip_val: 1
-  log_every_n_steps: 50
-  accumulate_grad_batches: 4
-  # DEBUGGING FLAGS. TODO: Split
-  # limit_train_batches: 0.001
-  # limit_val_batches: 0.01
-
-callbacks:
-  # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html
-  checkpoint:
-    _target_: pytorch_lightning.callbacks.ModelCheckpoint
-    dirpath: ${hydra.run.dir}
-    filename: "{step}-valid_f1-{epoch/valid_BinaryF1Score:.3f}"
-    monitor: epoch/valid_BinaryF1Score
-    mode: max
-    save_top_k: 1
-    save_last: True
-    # Is useful to set it to False when metric names contain / as this will result in extra folders
-    auto_insert_metric_name: False
-
-  # https://pytorch-lightning.readthedocs.io/en/stable/common/early_stopping.html
-  early_stop:
-    _target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping
-    monitor: epoch/valid_BinaryF1Score
-    mode: max
-    patience: 10
-
-  # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.LearningRateMonitor.html
-  lr_monitor:
-    _target_: pytorch_lightning.callbacks.LearningRateMonitor
-    logging_interval: step
-    log_momentum: False
-
-  richsummary:
-    _target_: pytorch_lightning.callbacks.RichModelSummary
-
-  richpbar:
-    _target_: pytorch_lightning.callbacks.RichProgressBar
\ No newline at end of file
+  - trainer: default
+  - dataloader: default
+  - metrics: binary
+  - logger: wandb
+  - module: default
+  - callbacks/checkpoint: binary
+  - callbacks/early_stop: binary
+  - callbacks/lr_monitor: default
+  - callbacks/richsummary: default
+  - misc: default
diff --git a/config/train_cls.yaml b/config/train_cls.yaml
index f14abf1..310bf87 100644
--- a/config/train_cls.yaml
+++ b/config/train_cls.yaml
@@ -12,93 +12,13 @@ defaults:
   - dataset: ppmi
   - scheduler: exp_decay
   - optim: adamw
-
-dataloader:
-  _target_: torch.utils.data.DataLoader
-  batch_size: 4
-  num_workers: 2
-  pin_memory: True
-  dataset: ${dataset}
-
-misc:
-  seed: 42
-  debug: False
-  modes: [ train, valid ]
-
-module:
-  _target_: sage.trainer.PLModule
-  _recursive_: False
-  augmentation:
-    spatial_size: [ 160 , 192 , 160 ]
-  load_from_checkpoint: # for lightning
-  load_model_ckpt: # For model checkpoint only
-  log_train_metrics: True
-  separate_lr:
-  save_dir: ${callbacks.checkpoint.dirpath}
-
-metrics:
-  acc:
-    _target_: torchmetrics.Accuracy
-    task: multiclass
-    num_classes: ${model.backbone.num_classes}
-    average: weighted
-  f1:
-    _target_: torchmetrics.F1Score
-    task: multiclass
-    num_classes: ${model.backbone.num_classes}
-    average: macro
-  auroc:
-    _target_: torchmetrics.AUROC
-    task: multiclass
-    num_classes: ${model.backbone.num_classes}
-
-logger:
-  _target_: pytorch_lightning.loggers.WandbLogger
-  project: brain-age
-  entity: 1pha
-  name: C ${model.name}
-
-# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
-trainer:
-  _target_: pytorch_lightning.Trainer
-  max_epochs: 300
-  devices: 1
-  accelerator: gpu
-  gradient_clip_val: 1
-  log_every_n_steps: 50
-  accumulate_grad_batches: 4
-  # DEBUGGING FLAGS. TODO: Split
-  # limit_train_batches: 0.001
-  # limit_val_batches: 0.01
-
-callbacks:
-  # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html
-  checkpoint:
-    _target_: pytorch_lightning.callbacks.ModelCheckpoint
-    dirpath: ${hydra.run.dir}
-    filename: "{step}-valid_f1-{epoch/valid_MulticlassF1Score:.3f}"
-    monitor: epoch/valid_MulticlassF1Score
-    mode: max
-    save_top_k: 1
-    save_last: True
-    # Is useful to set it to False when metric names contain / as this will result in extra folders
-    auto_insert_metric_name: False
-
-  # https://pytorch-lightning.readthedocs.io/en/stable/common/early_stopping.html
-  early_stop:
-    _target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping
-    monitor: epoch/valid_MulticlassF1Score
-    mode: max
-    patience: 10
-
-  # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.LearningRateMonitor.html
-  lr_monitor:
-    _target_: pytorch_lightning.callbacks.LearningRateMonitor
-    logging_interval: step
-    log_momentum: False
-
-  richsummary:
-    _target_: pytorch_lightning.callbacks.RichModelSummary
-
-  richpbar:
-    _target_: pytorch_lightning.callbacks.RichProgressBar
\ No newline at end of file
+  - trainer: default
+  - dataloader: default
+  - metrics: cls
+  - logger: wandb
+  - module: default
+  - callbacks/checkpoint: cls
+  - callbacks/early_stop: cls
+  - callbacks/lr_monitor: default
+  - callbacks/richsummary: default
+  - misc: default
diff --git a/config/train_mask.yaml b/config/train_mask.yaml
index 3a4ebaf..898aba4 100644
--- a/config/train_mask.yaml
+++ b/config/train_mask.yaml
@@ -10,104 +10,18 @@ defaults:
   - dataset: biobank_mask
   - scheduler: cosine_anneal_warmup
   - optim: lion
-
-dataloader:
-  _target_: torch.utils.data.DataLoader
-  batch_size: 32
-  num_workers: 4
-  pin_memory: True
-  dataset: ${dataset}
-
-misc:
-  seed: 42
-  debug: False
-  modes: [ train, valid, test ]
-
-module:
-  _target_: sage.trainer.PLModule
-  _recursive_: False
-  augmentation:
-    spatial_size: [ 160 , 192 , 160 ]
-  load_from_checkpoint: # for lightning
-  load_model_ckpt: # For model checkpoint only
-  separate_lr:
-  save_dir: ${callbacks.checkpoint.dirpath}
-
-metrics:
-  mae:
-    _target_: torchmetrics.MeanAbsoluteError
-  rmse:
-    _target_: torchmetrics.MeanSquaredError
-    squared: False
-  # corr:
-  #   _target_: torchmetrics.PearsonCorrCoef
-  r2:
-    _target_: torchmetrics.R2Score
-
-logger:
-  _target_: pytorch_lightning.loggers.WandbLogger
-  project: brain-age
-  entity: 1pha
-  name: M ${dataset.mask_idx} | ${misc.seed}
-  tags:
-    - model=${model.name}
-    - MASK
-
-# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
-trainer:
-  _target_: pytorch_lightning.Trainer
-  max_epochs: 500
-  devices: 1
-  accelerator: gpu
-  gradient_clip_val: 1
-  log_every_n_steps: 50
-  accumulate_grad_batches: 1
-  # DEBUGGING FLAGS. TODO: Split
-  # limit_train_batches: 0.001
-  # limit_val_batches: 0.01
+  - trainer: default
+  - dataloader: default
+  - metrics: reg
+  - logger: wandb_mask
+  - module: default
+  - callbacks/checkpoint: reg
+  - callbacks/early_stop: reg
+  - callbacks/lr_monitor: default
+  - callbacks/richsummary: default
+  - misc: default
 
 callbacks:
-  # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html
-  checkpoint:
-    _target_: pytorch_lightning.callbacks.ModelCheckpoint
-    dirpath: ${hydra:run.dir}
-    filename: "{step}-valid_mae{epoch/valid_MeanAbsoluteError:.3f}"
-    monitor: epoch/valid_MeanAbsoluteError
-    mode: min
-    save_top_k: 1
-    save_last: True
-    # Is useful to set it to False when metric names contain / as this will result in extra folders
-    auto_insert_metric_name: False
-
-  # checkpoint_lr:
-  #   _target_: pytorch_lightning.callbacks.ModelCheckpoint
-  #   dirpath: ${hydra:run.dir}
-  #   filename: "ckpt-step{step}-lr{_lr:.2e}-valid_mae{epoch/valid_MeanAbsoluteError:.3f}"
-  #   monitor: _lr
-  #   save_top_k: 7
-  #   mode: min
-  #   auto_insert_metric_name: False
-
-  # https://pytorch-lightning.readthedocs.io/en/stable/common/early_stopping.html
-  early_stop:
-    _target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping
-    monitor: epoch/valid_MeanSquaredError
-    mode: min
-    min_delta: 0.005
-    patience: 20
-
-  # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.LearningRateMonitor.html
-  lr_monitor:
-    _target_: pytorch_lightning.callbacks.LearningRateMonitor
-    logging_interval: step
-    log_momentum: False
-
-  richsummary:
-    _target_: pytorch_lightning.callbacks.RichModelSummary
-
-  richpbar:
-    _target_: pytorch_lightning.callbacks.RichProgressBar
-
   manual_ckpt:
     _target_: sage.trainer.callbacks.AnchorCheckpoint
     multiplier: 1
\ No newline at end of file
diff --git a/config/trainer/default.yaml b/config/trainer/default.yaml
index 2239c9c..aa102e9 100644
--- a/config/trainer/default.yaml
+++ b/config/trainer/default.yaml
@@ -1,3 +1,4 @@
+# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
 _target_: pytorch_lightning.Trainer
 max_epochs: 300
 devices: 1
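[Editor's note, not part of the patch] The diff above moves every inline block
(dataloader, metrics, logger, module, trainer, callbacks, misc) into Hydra
config groups, so each train*.yaml shrinks to a defaults list. Below is a
minimal sketch of how such a defaults list composes at runtime. It assumes
Hydra >= 1.2 (for version_base) and the standard hydra.main/instantiate API;
the file name train.py and the main function are hypothetical, while
config_name="train", sage.trainer.PLModule, and the group names come from the
patch itself.

    # train.py -- hypothetical entrypoint; a sketch only, not the repo's code.
    import hydra
    from hydra.utils import instantiate
    from omegaconf import DictConfig

    @hydra.main(version_base=None, config_path="config", config_name="train")
    def main(cfg: DictConfig) -> None:
        # "- metrics: reg" in the defaults list loads config/metrics/reg.yaml
        # into cfg.metrics; "- callbacks/checkpoint: reg" loads
        # config/callbacks/checkpoint/reg.yaml into cfg.callbacks.checkpoint,
        # which is why ${callbacks.checkpoint.dirpath} in module/default.yaml
        # still resolves after the refactor.
        module = instantiate(cfg.module)  # PLModule; _recursive_: False means
                                          # it receives its sub-configs raw.
        callbacks = [instantiate(cb) for cb in cfg.callbacks.values()]
        trainer = instantiate(cfg.trainer, callbacks=callbacks,
                              logger=instantiate(cfg.logger))
        trainer.fit(module)  # assumes PLModule builds its own dataloaders
                             # from cfg.dataloader / ${dataset}

    if __name__ == "__main__":
        main()

With the groups factored out, switching tasks becomes a group override on the
command line (e.g. metrics=binary callbacks/checkpoint=binary) rather than an
edit to a hundred-line YAML file; the diffstat's 369 deletions against 110
insertions is exactly that duplication being collapsed into shared defaults.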