Skip to content

Commit

Permalink
Fix PPMI dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
1pha committed Mar 5, 2024
1 parent 30a7e99 commit 407643b
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 54 deletions.
4 changes: 1 addition & 3 deletions config/dataset/ppmi.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
_target_: sage.data.mask.UKB_MaskDataset
root:
label_name:
_target_: sage.data.ppmi.PPMIClassification
mode: train
valid_ratio: 0.1
seed: ${misc.seed}
48 changes: 20 additions & 28 deletions config/train_cls.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,60 +6,52 @@ hydra:

defaults:
- _self_
- model: resnet_cls
- dataset: biobank_cls
- scheduler: onecycle
- model: convnext_cls
- dataset: ppmi
- scheduler: cosine_anneal_warmup
- optim: lion

dataloader:
_target_: torch.utils.data.DataLoader
batch_size: 64
num_workers: 4
batch_size: 4
num_workers: 2
pin_memory: True
dataset: ${dataset}

optim:
_target_: torch.optim.AdamW
lr: 1e-4
weight_decay: 1e-6
eps: 1e-8

misc:
seed: 42
debug: False
modes: [ train, valid, test ]
modes: [ train, valid ]

module:
_target_: sage.trainer.PLModule
_recursive_: False
augmentation:
_target_: sage.data.augment
spatial_size: [ 96, 96, 96 ]
mask: False
mask_threshold: 0.1
spatial_size: [ 160 , 192 , 160 ]
load_from_checkpoint: # for lightning
load_model_ckpt: # For model checkpoint only
separate_lr:
save_dir: ${callbacks.checkpoint.dirpath}

metrics:
acc:
_target_: torchmetrics.Accuracy
task: binary
task: multiclass
num_classes: ${model.backbone.num_classes}
f1:
_target_: torchmetrics.F1Score
task: binary
task: multiclass
num_classes: ${model.backbone.num_classes}
auroc:
_target_: torchmetrics.AUROC
task: binary
task: multiclass
num_classes: ${model.backbone.num_classes}

logger:
_target_: pytorch_lightning.loggers.WandbLogger
project: brain-age
entity: 1pha
name: C ${model.name} | mask=${module.mask}
tags:
- model=${model.name}
- mask=${module.mask}
- CLS
name: C ${model.name}

# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
trainer:
Expand All @@ -69,7 +61,7 @@ trainer:
accelerator: gpu
gradient_clip_val: 1
log_every_n_steps: 50
accumulate_grad_batches: 1
accumulate_grad_batches: 4
# DEBUGGING FLAGS. TODO: Split
# limit_train_batches: 0.001
# limit_val_batches: 0.01
Expand All @@ -79,8 +71,8 @@ callbacks:
checkpoint:
_target_: pytorch_lightning.callbacks.ModelCheckpoint
dirpath: ${hydra:run.dir}
filename: "{step}-valid_f1-{epoch/valid_BinaryF1Score:.3f}"
monitor: epoch/valid_BinaryF1Score
filename: "{step}-valid_f1-{epoch/valid_MulticlassF1Score:.3f}"
monitor: epoch/valid_MulticlassF1Score
mode: max
save_top_k: 1
save_last: True
Expand All @@ -90,7 +82,7 @@ callbacks:
# https://pytorch-lightning.readthedocs.io/en/stable/common/early_stopping.html
early_stop:
_target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping
monitor: epoch/valid_BinaryF1Score
monitor: epoch/valid_MulticlassF1Score
mode: max
patience: 10

Expand Down
39 changes: 24 additions & 15 deletions sage/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def open_scan(fname: str) -> Tuple[np.array, dict]:
arr = open_npy(fname)
elif suffix == ".h5":
arr, meta = open_h5(fname=fname)
elif suffix in {".nii", ".nii.gz"}:
elif suffix in {".nii", ".nii.gz"} or str(fname).endswith(".nii.gz"):
arr = nib.load(filename=fname)
arr = arr.get_fdata()
return arr, meta
Expand Down Expand Up @@ -70,7 +70,7 @@ def open_h5_nifti(fname: str) -> nib.nifti1.Nifti1Image:
nii = nib.nifti1.Nifti1Image(dataobj=arr, affine=affine)
return nii


class DatasetBase(Dataset):
""" This Dataset class takes `.csv` labels with following scheme
cols: pid (primary_key) | label | abspath
Expand All @@ -92,6 +92,8 @@ def __init__(self,
pk_col: str,
pid_col: str,
label_col: str,
mod_col: str = None,
modality: List[str] = None,
exclusion_fname: str = "exclusion.csv",
augmentation: str = "monai",
seed: int = 42,):
Expand All @@ -103,10 +105,12 @@ def __init__(self,
root = Path(root)
labels: pd.DataFrame = self.load_labels(root=root, label_name=label_name, mode=mode)
labels: pd.DataFrame = self.remove_duplicates(labels=labels)
if mod_col and modality:
labels: pd.DataFrame = self.filter_data(labels=labels, col=mod_col, leave=modality)
self.mode = mode

if mode != "test":
labels: pd.DataFrame = self._split_data(labels=labels, valid_ratio=valid_ratio,
labels: pd.DataFrame = self.split_data(labels=labels, valid_ratio=valid_ratio,
pid_col=pid_col, mode=mode, seed=seed)

self.sanity_check(labels=labels, path_col=path_col)
Expand All @@ -127,13 +131,18 @@ def _exclude_data(self,
root: Path,
exclusion_fname: str = "exclusion.csv") -> List[Path]:
return lst

def filter_data(self, labels: pd.DataFrame, col: str, leave: List[str]) -> pd.DataFrame:
    """Return only the rows of `labels` whose `col` value appears in `leave`.

    Used to restrict a label table to selected imaging modalities
    (e.g. keep only T1 scans by their `Description` value).
    """
    keep = set(leave)
    return labels[labels[col].isin(keep)]

def _split_data(self,
labels: pd.DataFrame,
valid_ratio: float = 0.1,
pid_col: str = "",
mode: str = "train",
seed: int = 42) -> pd.DataFrame:
def split_data(self,
labels: pd.DataFrame,
valid_ratio: float = 0.1,
pid_col: str = "",
mode: str = "train",
seed: int = 42) -> pd.DataFrame:
# Data split, used fixated seed
if pid_col:
pid = labels[pid_col].unique().tolist()
Expand Down Expand Up @@ -320,15 +329,15 @@ def _age_filter(self, files: list) -> list:
logger.info("#%s scans were excluded since they were not found as h5 files in biobank", len(passed))
return files

def split_data(self,
               files: list,
               valid_ratio: float = 0.1,
               mode: str = "train",
               seed: int = 42) -> pd.DataFrame:
    """Override of the base split: drop age-filtered scans, then split.

    First removes scans rejected by `self._age_filter`, then delegates the
    train/valid partition to the parent implementation with the same
    ratio, mode and fixed seed.
    """
    eligible = self._age_filter(files=files)
    return super().split_data(files=eligible, valid_ratio=valid_ratio,
                              mode=mode, seed=seed)

def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
Expand Down
42 changes: 34 additions & 8 deletions sage/data/ppmi.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from pathlib import Path
from typing import Tuple
from typing import Tuple, List

import torch
import pandas as pd
from sklearn.model_selection import train_test_split

from sage.data.dataloader import DatasetBase, open_scan
import sage.constants as C
Expand All @@ -13,9 +12,37 @@
logger = get_logger(name=__name__)


class PPMIClassification(DatasetBase):
NAME = "PPMI_CLS"
class PPMIBase(DatasetBase):
    """Shared base for PPMI datasets (classification and age regression).

    Supplies PPMI-specific defaults (label csv name, column names) and maps
    short modality keys (e.g. "t1") to the values stored in the label
    file's `Description` column before delegating to `DatasetBase`.
    """
    NAME = "PPMI"
    # Diagnostic group -> integer class label.
    MAPPER2INT = dict(Control=0, PD=1, SWEDD=2, Prodromal=3)
    # Short modality key -> value in the label csv's `Description` column.
    MOD_MAPPER = dict(t1="T1-anatomical", t2="T2 in T1-anatomical space")

    def __init__(self,
                 root: Path | str = C.PPMI_DIR,
                 label_name: str = "ppmi_label.csv",
                 mode: str = "train",
                 valid_ratio: float = .1,
                 path_col: str = "abs_path",
                 pk_col: str = "Image Data ID",
                 pid_col: str = "Subject",
                 label_col: str = "Group",
                 mod_col: str = "Description",
                 modality: List[str] | None = None,
                 exclusion_fname: str = "exclusion.csv",
                 augmentation: str = "monai",
                 seed: int = 42,):
        # Fix: the previous default `modality=["t1"]` was a mutable default
        # argument shared across calls. `None` now means "T1 only" (the old
        # default), so existing callers see identical behavior.
        if modality is None:
            modality = ["t1"]
        # Translate short keys into Description-column values; an unknown
        # key raises KeyError here, before any data loading starts.
        modality = [self.MOD_MAPPER[m] for m in modality]
        super().__init__(root=root, label_name=label_name, mode=mode, valid_ratio=valid_ratio,
                         path_col=path_col, pk_col=pk_col, pid_col=pid_col, label_col=label_col,
                         mod_col=mod_col, modality=modality, exclusion_fname=exclusion_fname,
                         augmentation=augmentation, seed=seed)

    def _load_data(self, idx: int) -> Tuple[torch.Tensor]:
        """Subclasses must return the (scan, target) tensors for row `idx`."""
        raise NotImplementedError


class PPMIClassification(PPMIBase):
NAME = "PPMI-CLS"
def __init__(self,
root: Path | str = C.PPMI_DIR,
label_name: str = "ppmi_label.csv",
Expand All @@ -31,7 +58,7 @@ def __init__(self,
super().__init__(root=root, label_name=label_name, mode=mode, valid_ratio=valid_ratio,
path_col=path_col, pk_col=pk_col, pid_col=pid_col, label_col=label_col,
exclusion_fname=exclusion_fname, augmentation=augmentation, seed=seed)

def _load_data(self, idx: int) -> Tuple[torch.Tensor]:
data: dict = self.labels.iloc[idx].to_dict()
arr, _ = open_scan(data[self.path_col])
Expand All @@ -46,9 +73,8 @@ def _load_data(self, idx: int) -> Tuple[torch.Tensor]:
return arr, label


class PPMIAgeRegression(DatasetBase):
class PPMIAgeRegression(PPMIBase):
NAME = "PPMI_AGE"
MAPPER2INT = dict(Control=0, PD=1, SWEDD=2, Prodromal=3)
def __init__(self,
root: Path | str = C.PPMI_DIR,
label_name: str = "ppmi_label.csv",
Expand All @@ -64,7 +90,7 @@ def __init__(self,
super().__init__(root=root, label_name=label_name, mode=mode, valid_ratio=valid_ratio,
path_col=path_col, pk_col=pk_col, pid_col=pid_col, label_col=label_col,
exclusion_fname=exclusion_fname, augmentation=augmentation, seed=seed)

def _load_data(self, idx: int) -> Tuple[torch.Tensor]:
data: dict = self.labels.iloc[idx].to_dict()
arr, _ = open_scan(data[self.path_col])
Expand Down

0 comments on commit 407643b

Please sign in to comment.