-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add occlusion training warpup code to notebooks
- Loading branch information
Showing
1 changed file
with
359 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,359 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Wrap-up codes for Occlusion Training" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"83" | ||
] | ||
}, | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"from collections import defaultdict\n", | ||
"\n", | ||
"import wandb\n", | ||
"\n", | ||
"# Open target RoI indices\n", | ||
"with open(\"assets/dkt_indices.txt\", mode=\"r\") as f:\n", | ||
" gt_rois = {int(roi.strip()) for roi in f.readlines()}\n", | ||
"\n", | ||
"api = wandb.Api()\n", | ||
"runs = api.runs(path=\"1pha/brain-age\",\n", | ||
" filters={\"config.dataloader.dataset._target_\": \"sage.data.mask.UKB_MaskDataset\"})\n", | ||
"len(runs)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"runs_dict = defaultdict(dict)\n", | ||
"runs_idx = defaultdict(set)\n", | ||
"for run in runs:\n", | ||
" state = run.state\n", | ||
" name: str = run.name.split(\" \")[1]\n", | ||
" if not name.isnumeric():\n", | ||
" # Skip outlier run name that does not follow `M XXXX | SEED`\n", | ||
" continue\n", | ||
" idx = int(name)\n", | ||
" if state == \"finished\":\n", | ||
" runs_dict[\"success\"][idx] = run\n", | ||
" runs_idx[\"success\"].add(idx)\n", | ||
" elif state == \"failed\":\n", | ||
" runs_dict[\"fail\"][idx] = run\n", | ||
" runs_idx[\"fail\"].add(idx)\n", | ||
" else:\n", | ||
" runs_dict[\"etc\"][idx] = (run, state)\n", | ||
" runs_idx[\"etc\"].add(idx)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(66, 11, 3)" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"len(runs_idx[\"success\"]), len(runs_idx[\"fail\"]), len(runs_idx[\"etc\"])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"assert gt_rois.issuperset(runs_idx[\"success\"])\n", | ||
"leftover = gt_rois - runs_idx[\"success\"]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"34" | ||
] | ||
}, | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"len(leftover)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Actual failed runs: 9\n", | ||
"1025\n", | ||
"2018\n", | ||
"2024\n", | ||
"1002\n", | ||
"77\n", | ||
"2030\n", | ||
"1009\n", | ||
"54\n", | ||
"24\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"true_fail = {idx for idx in runs_idx[\"fail\"] if idx not in runs_idx[\"success\"]}\n", | ||
"print(f\"Actual failed runs: {len(true_fail)}\")\n", | ||
"for fail in true_fail:\n", | ||
" print(fail)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Runs checked manually that has final test performance \n", | ||
"manual_check_success: set = {1025, 2018, 2024, 1002, 77, 54, 24}\n", | ||
"\n", | ||
"# Was not able to get test performance but has final checkpoint\n", | ||
"manual_check_ckpt: set = {2030, 1009}\n", | ||
"\n", | ||
"# Sanity check for empty runs\n", | ||
"assert len(true_fail - (manual_check_ckpt | manual_check_success)) == 0\n", | ||
"assert len((manual_check_ckpt | manual_check_success) - true_fail) == 0" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"25" | ||
] | ||
}, | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"leftover = leftover - true_fail\n", | ||
"len(leftover)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{2017}" | ||
] | ||
}, | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"leftover & runs_idx[\"etc\"]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Mask idx 2017 turned out to be crashed. Crahsed runs should be re-batched." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{2,\n", | ||
" 46,\n", | ||
" 47,\n", | ||
" 49,\n", | ||
" 50,\n", | ||
" 251,\n", | ||
" 252,\n", | ||
" 253,\n", | ||
" 254,\n", | ||
" 255,\n", | ||
" 1010,\n", | ||
" 1011,\n", | ||
" 1012,\n", | ||
" 1013,\n", | ||
" 1014,\n", | ||
" 1015,\n", | ||
" 1016,\n", | ||
" 2005,\n", | ||
" 2006,\n", | ||
" 2007,\n", | ||
" 2008,\n", | ||
" 2017,\n", | ||
" 2031,\n", | ||
" 2034,\n", | ||
" 2035}" | ||
] | ||
}, | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"leftover" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Split runs and re-batch" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import random\n", | ||
"\n", | ||
"def split_into_chunks(data_set, num_chunks):\n", | ||
" # Convert set to list for shuffling\n", | ||
" data_list = list(data_set)\n", | ||
" \n", | ||
" # Calculate the size of each chunk\n", | ||
" chunk_size = len(data_list) // num_chunks\n", | ||
" remainder = len(data_list) % num_chunks\n", | ||
" \n", | ||
" # Split the list into chunks\n", | ||
" chunks = []\n", | ||
" start = 0\n", | ||
" for i in range(num_chunks):\n", | ||
" chunk_length = chunk_size + (1 if i < remainder else 0)\n", | ||
" chunks.append(data_list[start:start + chunk_length])\n", | ||
" start += chunk_length\n", | ||
" \n", | ||
" flattened_chunks = [item for sublist in chunks for item in sublist]\n", | ||
" assert set(flattened_chunks) == data_set, \"Chunks do not contain all elements of the data set\"\n", | ||
" return chunks\n", | ||
"\n", | ||
"\n", | ||
"machines = [\"185-0\", \"185-1\", \"245-0\", \"245-1\"]\n", | ||
"chunks = split_into_chunks(data_set=leftover, num_chunks=4)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[[2, 46, 47, 49, 50, 2005, 2006],\n", | ||
" [2007, 2008, 2017, 2031, 2034, 2035],\n", | ||
" [1010, 1011, 1012, 1013, 1016, 1014],\n", | ||
" [1015, 251, 252, 253, 254, 255]]" | ||
] | ||
}, | ||
"execution_count": 13, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"chunks" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"for chunk, machine in zip(chunks, machines):\n", | ||
" with open(f\"assets/dkt_leftover_{machine}.txt\", mode=\"w\") as f:\n", | ||
" for roi in chunk:\n", | ||
" f.write(f\"{roi}\\n\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "age", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.13" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |