Skip to content

Commit

Permalink
Add ADNI to codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
1pha committed May 9, 2024
1 parent b175a9f commit 58bbe60
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 26 deletions.
4 changes: 4 additions & 0 deletions config/dataset/adni_tertiary.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Hydra dataset config for three-way ("tertiary") ADNI diagnosis classification
# (labels CN / MCI / AD, per ADNIClassification.MAPPER2INT).
_target_: sage.data.adni.ADNIClassification
mode: train          # which split this dataset instance serves
valid_ratio: 0.1     # fraction held out for validation
seed: ${misc.seed}   # reuse the global experiment seed so splits are reproducible
2 changes: 0 additions & 2 deletions config/sweep/adni_sweep.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ metric:
goal: maximize
name: test_acc
parameters:
model:
values: [ resnet_binary , convnext_binary ]
optim:
values: [ adamw , lion ]
scheduler:
Expand Down
2 changes: 1 addition & 1 deletion config/train_cls.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ hydra:
defaults:
- _self_
- model: convnext_cls
- dataset: ppmi
- dataset: adni
- scheduler: exp_decay
- optim: adamw
- trainer: default
Expand Down
2 changes: 1 addition & 1 deletion sage/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@

EXT_BASE = DATA_BASE / "brain"
PPMI_DIR = EXT_BASE / "PPMI"
ADNI_DIR = EXT_BASE / "ADNI"
ADNI_DIR = Path("adni") / "ADNI_3_reg"
61 changes: 42 additions & 19 deletions sage/data/adni.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,19 @@

class ADNIBase(DatasetBase):
NAME = "ADNI"
MAPPER2INT = {"ADNI 2": 0, "ADNI 3": 1}
def __init__(self,
root: Path | str = C.ADNI_DIR,
label_name: str = "adni_label.csv",
label_name: str = "adni_labels_240509.csv",
mode: str = "train",
valid_ratio: float = .1,
path_col: str = "abs_path",
pk_col: str = "Subject ID",
pid_col: str = "Subject ID",
label_col: str = "Phase",
strat_col: str = "Phase",
path_col: str = "filepath",
pk_col: str = "Subject",
pid_col: str = "Subject",
label_col: str = "Group",
strat_col: str = "Group",
mod_col: str = None,
modality: List[str] = None,
exclusion_fname: str = "donotuse-adni.txt",
exclusion_fname: str = "",
augmentation: str = "monai",
seed: int = 42,):
logger.warn("Please note that ADNI dataset label file should not have the exclusion file.")
Expand Down Expand Up @@ -58,20 +57,20 @@ def _exclude_data(self, labels: pd.DataFrame, pk_col: str, root: Path,

class ADNIClassification(ADNIBase):
NAME = "ADNI-CLS"
MAPPER2INT = {"ADNI 2": 0, "ADNI 3": 1}
MAPPER2INT = {"CN": 0, "MCI": 1, "AD": 2}
def __init__(self,
root: Path | str = C.ADNI_DIR,
label_name: str = "adni_label.csv",
label_name: str = "adni_labels_240509.csv",
mode: str = "train",
valid_ratio: float = .1,
path_col: str = "abs_path",
pk_col: str = "Subject ID",
pid_col: str = "Subject ID",
label_col: str = "Phase",
strat_col: str = "Phase",
mod_col: str = None,
modality: List[str] = None,
exclusion_fname: str = "donotuse-adni.txt",
path_col: str = "filepath",
pk_col: str = "Subject",
pid_col: str = "Subject",
label_col: str = "Group",
strat_col: str = "Group",
mod_col: str = "Group",
modality: List[str] = ["CN", "MCI", "AD"],
exclusion_fname: str = "",
augmentation: str = "monai",
seed: int = 42,):
super().__init__(root=root, label_name=label_name, mode=mode, valid_ratio=valid_ratio,
Expand All @@ -90,4 +89,28 @@ def _load_data(self, idx: int) -> Tuple[torch.Tensor]:
logger.warn("Wrong label: %s\nData:%s", data[self.label_col], arr)
raise
label = torch.tensor(label, dtype=torch.long)
return arr, label
return arr, label


class ADNIFullClassification(ADNIClassification):
    """ADNI classification over the full six-way diagnosis label set.

    Unlike ``ADNIClassification`` (CN / MCI / AD only), this variant maps all
    six diagnostic groups to integers and disables modality filtering by
    default (``mod_col=None``, ``modality=None``), so no rows are dropped by
    diagnosis group.
    """
    NAME = "ADNI-ALL-CLS"
    # Integer encoding for the six ADNI diagnostic groups, ordered roughly
    # by disease progression.
    MAPPER2INT = {"CN": 0, "SMC": 1, "EMCI": 2, "MCI": 3, "LMCI": 4, "AD": 5}

    def __init__(self,
                 root: Path | str = C.ADNI_DIR,
                 label_name: str = "adni_labels_240509.csv",
                 mode: str = "train",
                 valid_ratio: float = .1,
                 path_col: str = "filepath",
                 pk_col: str = "Subject",
                 pid_col: str = "Subject",
                 label_col: str = "Group",
                 strat_col: str = "Group",
                 mod_col: str = None,
                 modality: List[str] = None,
                 exclusion_fname: str = "",
                 augmentation: str = "monai",
                 seed: int = 42,):
        # Pure forwarder: this subclass only exists to change the default
        # label mapper and turn off modality filtering; everything else is
        # delegated to the parent untouched.
        forwarded = dict(root=root, label_name=label_name, mode=mode,
                         valid_ratio=valid_ratio, path_col=path_col,
                         pk_col=pk_col, pid_col=pid_col, label_col=label_col,
                         strat_col=strat_col, mod_col=mod_col,
                         modality=modality, exclusion_fname=exclusion_fname,
                         augmentation=augmentation, seed=seed)
        super().__init__(**forwarded)
5 changes: 4 additions & 1 deletion sage/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def __init__(self,
def load_labels(self, root: Path, label_name: str = None, mode: str = None) -> pd.DataFrame:
    """Read the label table from ``root / label_name`` as a DataFrame.

    If the csv carries a precomputed ``split`` column and *mode* is given,
    only rows of the matching split are kept. Any mode other than "test"
    (e.g. "valid") selects the "train" partition, since the train/valid
    division is performed later from the training rows.
    """
    labels = pd.read_csv(root / label_name)
    if mode is not None and "split" in labels.columns:
        wanted = "test" if mode == "test" else "train"
        labels = labels[labels["split"] == wanted]
    return labels

def remove_duplicates(self, labels: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -156,7 +159,7 @@ def split_data(self,
trn_pid, val_pid = train_test_split(pid, test_size=valid_ratio, random_state=seed, stratify=y)
trn = labels[labels[pid_col].isin(trn_pid)]
val = labels[labels[pid_col].isin(val_pid)]
else:
else:
trn, val = train_test_split(labels, test_size=valid_ratio, random_state=seed)
labels = {"train": trn, "valid": val}.get(mode, None)
if labels is None:
Expand Down
4 changes: 2 additions & 2 deletions sweep_command.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ sweep_adni() {
echo "Sweep on ADNI"
python sweep.py --sweep_cfg_name=adni_sweep.yaml\
--wandb_project=adni\
--config_name=train_binary.yaml\
--config_name=train_cls.yaml\
--sweep_prefix='Scratch'\
--overrides="['dataset=adni']"
--overrides="['dataloader.batch_size=8']"
}

# Check the input argument and call the appropriate function
Expand Down

0 comments on commit 58bbe60

Please sign in to comment.