Skip to content

Commit

Permalink
Move notebook files and shell scripts into subdirectories
Browse files Browse the repository at this point in the history
  • Loading branch information
1pha committed Dec 21, 2023
1 parent 2e1ffba commit 3373156
Show file tree
Hide file tree
Showing 14 changed files with 212 additions and 246 deletions.
235 changes: 138 additions & 97 deletions RQ/q0_preliminary.ipynb

Large diffs are not rendered by default.

16 changes: 0 additions & 16 deletions batch_infer.sh

This file was deleted.

48 changes: 0 additions & 48 deletions batch_infer2.sh

This file was deleted.

76 changes: 0 additions & 76 deletions infer_ckpt.py

This file was deleted.

72 changes: 66 additions & 6 deletions inference.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,76 @@
import os
import argparse
from pathlib import Path

import hydra
import omegaconf

import sage


logger = sage.utils.get_logger(name=__name__)


@hydra.main(config_path="config", config_name="train.yaml", version_base="1.1")
def main(config: omegaconf.DictConfig):
logger.info("Start Training")
sage.trainer.inference(config)
MASK_DIR = Path("assets/masks")


def parse_args():
parser = argparse.ArgumentParser()

parser.add_argument("--path", type=str, help="Leaf node directory name. e.g. resnet10t-mask")
parser.add_argument("--root", default="meta_brain/weights/default/", type=str, help="Root directory where weights resides")

parser.add_argument("--mask", type=str, default=False, help="Masking inference")

parser.add_argument("--batch_size", type=int, default=1, help="batch size during inference")

parser.add_argument("--infer_xai", type=str, default="False", help="Infer xai or not")
parser.add_argument("--top_k", type=float, default=0.99, help="")
parser.add_argument("--xai_method", type=str, default="gbp", help="Which explainability method to use")
parser.add_argument("--baseline", type=bool, default=False, help="Baseline brain for Integrated gradients")

args = parser.parse_args()
return args


def main(args):
root = Path(args.root) / args.path
# Starting with numbers is the checkpoint recorded by best monitoring checkpoint via save_top_k=1
weight = sorted(root.glob("*.ckpt"))[0]

mask = sage.utils.parse_bool(args.mask)
overrides = ["misc.modes=[train,valid,test]",
f"module.load_model_ckpt={weight}",
f"dataloader.batch_size={args.batch_size}"]
# f"module.mask={MASK_DIR/mask if mask else 'False'}"]

infer_xai: bool = sage.utils.parse_bool(args.infer_xai)
if infer_xai:
logger.info("Infer XAI map")
overrides += [
"+module.target_layer_index=-1",
"module._target_=sage.xai.trainer.XPLModule",
f"+module.top_k_percentile={args.top_k}",
f"+module.xai_method={args.xai_method}",
"+trainer.inference_mode=False",
"trainer.accelerator=gpu"
]
if args.xai_method == "ig":
overrides += [f"+module.baseline={sage.utils.parse_bool(args.baseline)}"]
else:
logger.info("Infer Metrics")

with hydra.initialize(config_path=str(root / ".hydra"), version_base="1.1"):
config = hydra.compose(config_name="config.yaml", overrides=overrides, return_hydra_config=True)
# TODO: Hydra key-interpolation did not work
for callback in config.callbacks:
_cb = config.callbacks[callback]
_cb.update({"dirpath": root} if "dirpath" in _cb else {})

logger.info("Start Inference")
os.makedirs(root, exist_ok=True)
sage.trainer.inference(config, root_dir=root)


if __name__=="__main__":
main()
args = parse_args()
main(args)
5 changes: 5 additions & 0 deletions notebooks/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Analysis notebooks

Archive of notebooks used for sanity checking.

Notebooks inside `tmp` directory are the ones that are undone.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion project_mask.ipynb → notebooks/tmp/project_mask.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion sage/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def __init__(self,
task: str = "reg"):
super().__init__()
logger.info("Start Initiating model %s", name.upper())
# self.backbone = torch.compile(backbone)
self.backbone = backbone
self.backbone = torch.compile(self.backbone, dynamic=True)
self.criterion = criterion
self.NAME = name
self.TASK = task
Expand Down
2 changes: 1 addition & 1 deletion sage/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def finalize_inference(prediction: list,
if name.startswith("C"):
logger.info("Classification data given:")
_cls_infrence(preds=preds, target=target, root_dir=root_dir, run_name=run_name)
elif name.startswith("R"):
elif name[0] in set("R", "M"):
logger.info("Regression data given:")
_reg_infrence(preds=preds, target=target, root_dir=root_dir, run_name=run_name)
else:
Expand Down

0 comments on commit 3373156

Please sign in to comment.