Skip to content

Commit

Permalink
Minor fixes. Add inference code
Browse files Browse the repository at this point in the history
  • Loading branch information
1pha committed Nov 8, 2024
1 parent 096154f commit 2d725d0
Show file tree
Hide file tree
Showing 10 changed files with 503 additions and 79 deletions.
23 changes: 20 additions & 3 deletions inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
import argparse
from pathlib import Path

Expand All @@ -17,6 +18,7 @@ def parse_args():
parser = argparse.ArgumentParser()

parser.add_argument("--path", type=str, help="Leaf node directory name. e.g. resnet10t-mask")
parser.add_argument("--ckpt_step", default=None, type=int, help="Finds checkpoint step.")
parser.add_argument("--root", default="meta_brain/weights/default/", type=str, help="Root directory where weights resides")

parser.add_argument("--batch_size", type=int, default=1, help="batch size during inference")
Expand All @@ -33,8 +35,23 @@ def parse_args():
def main(args):
root = Path(args.root) / args.path
# Starting with numbers is the checkpoint recorded by best monitoring checkpoint via save_top_k=1
weight = sorted(root.glob("*.ckpt"))[0]

ckpts = sorted(root.glob("*.ckpt"))
if args.ckpt_step is None:
weight = ckpts[0]
else:
ckpts_step = [ckpt.stem for ckpt in ckpts]
ckpt_idx = [idx for idx, ckpt in enumerate(ckpts_step)
if (ckpt.startswith("step") and int(ckpt.split("-")[0][4:]) == args.ckpt_step)]
if len(ckpt_idx):
# Yes there is a finding ckpt_step
weight = ckpts[ckpt_idx[0]]
else:
# No step of checkpoint looking for.
logger.info("No step of checkpoint you are looking for %s", args.ckpt_step)
logger.info("Weight list: %s", ckpts)
# 안되는건 그냥 나중에 step이름 쑤셔넣는거로 대체
sys.exit()

overrides = ["misc.modes=[train,valid,test]",
f"module.load_model_ckpt={weight}",
f"dataloader.batch_size={args.batch_size}"]
Expand Down Expand Up @@ -65,7 +82,7 @@ def main(args):

logger.info("Start Inference")
os.makedirs(root, exist_ok=True)
sage.trainer.inference(config, root_dir=root)
sage.trainer.inference(config, root_dir=root, ckpt_step=args.ckpt_step)


if __name__=="__main__":
Expand Down
12 changes: 7 additions & 5 deletions sage/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
# Reshape target size
SPATIAL_SIZE = (160, 192, 160)

DATA_BASE = Path(config.get("DATA_BASE", Path.home() / "data" / "hdd01" / "1pha"))
BIOBANK_PATH = DATA_BASE / "h5"
DATA_BASE1 = Path(config.get("DATA_BASE", Path.home() / "data" / "hdd01" / "1pha"))
BIOBANK_PATH = DATA_BASE1 / "h5"
ADNI_DIR = DATA_BASE1 / "brain" / "ADNI_08_12_2024" / "3_reg"

EXT_BASE = DATA_BASE / "brain"
PPMI_DIR = EXT_BASE / "PPMI"
ADNI_DIR = Path("adni") / "ADNI_3_reg"
DATA_BASE3 = Path.home() / "data" / "hdd03" / "1pha"
PPMI_DIR = DATA_BASE3 / "ppmi" / "PPMI_4_reg"
# PPMI_DIR = Path("~/workspace/brain-age-prediction/ppmi/PPMI_4_reg")
ADNI_DIR = Path("adni")
44 changes: 34 additions & 10 deletions sage/data/adni.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ class ADNIBase(DatasetBase):
NAME = "ADNI"
def __init__(self,
root: Path | str = C.ADNI_DIR,
label_name: str = "adni_labels_240509.csv",
label_name: str = "adni_screen_labels_Sept11_test15_2024.csv",
mode: str = "train",
valid_ratio: float = .1,
path_col: str = "filepath",
pk_col: str = "Subject",
pid_col: str = "Subject",
label_col: str = "Group",
strat_col: str = "Group",
label_col: str = "DX_bl",
strat_col: str = "DX_bl",
mod_col: str = None,
modality: List[str] = None,
exclusion_fname: str = "",
Expand Down Expand Up @@ -60,15 +60,15 @@ class ADNIClassification(ADNIBase):
MAPPER2INT = {"CN": 0, "MCI": 1, "AD": 2}
def __init__(self,
root: Path | str = C.ADNI_DIR,
label_name: str = "adni_labels_240509.csv",
label_name: str = "adni_screen_labels_Sept11_test15_2024.csv",
mode: str = "train",
valid_ratio: float = .1,
path_col: str = "filepath",
pk_col: str = "Subject",
pid_col: str = "Subject",
label_col: str = "Group",
strat_col: str = "Group",
mod_col: str = "Group",
label_col: str = "DX_bl",
strat_col: str = "DX_bl",
mod_col: str = "DX_bl",
modality: List[str] = ["CN", "MCI", "AD"],
exclusion_fname: str = "",
augmentation: str = "monai",
Expand All @@ -92,19 +92,43 @@ def _load_data(self, idx: int) -> Tuple[torch.Tensor]:
return arr, label


class ADNIBinary(ADNIClassification):
NAME = "ADNI-Binary"
MAPPER2INT = {"CN": 0, "AD": 1}
def __init__(self,
root: Path | str = C.ADNI_DIR,
label_name: str = "adni_screen_labels_Sept11_test15_2024.csv",
mode: str = "train",
valid_ratio: float = .1,
path_col: str = "filepath",
pk_col: str = "Subject",
pid_col: str = "Subject",
label_col: str = "DX_bl",
strat_col: str = "DX_bl",
mod_col: str = "DX_bl",
modality: List[str] = ["CN", "AD"],
exclusion_fname: str = "",
augmentation: str = "monai",
seed: int = 42,):
super().__init__(root=root, label_name=label_name, mode=mode, valid_ratio=valid_ratio,
path_col=path_col, pk_col=pk_col, pid_col=pid_col, label_col=label_col,
strat_col=strat_col, mod_col=mod_col, modality=modality,
exclusion_fname=exclusion_fname, augmentation=augmentation, seed=seed)


class ADNIFullClassification(ADNIClassification):
NAME = "ADNI-ALL-CLS"
MAPPER2INT = {"CN": 0, "SMC": 1, "EMCI": 2, "MCI": 3, "LMCI": 4, "AD": 5}
def __init__(self,
root: Path | str = C.ADNI_DIR,
label_name: str = "adni_labels_240509.csv",
label_name: str = "adni_screen_labels_Sept11_test15_2024.csv",
mode: str = "train",
valid_ratio: float = .1,
path_col: str = "filepath",
pk_col: str = "Subject",
pid_col: str = "Subject",
label_col: str = "Group",
strat_col: str = "Group",
label_col: str = "DX_bl",
strat_col: str = "DX_bl",
mod_col: str = None,
modality: List[str] = None,
exclusion_fname: str = "",
Expand Down
Loading

0 comments on commit 2d725d0

Please sign in to comment.