Skip to content

Commit

Permalink
Add occlusion training warpup code to notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
1pha committed Mar 18, 2024
1 parent c2abd9a commit b5cf589
Showing 1 changed file with 359 additions and 0 deletions.
359 changes: 359 additions & 0 deletions notebooks/occlusion_training_wrapup.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,359 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Wrap-up codes for Occlusion Training"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"83"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from collections import defaultdict\n",
"\n",
"import wandb\n",
"\n",
"# Open target RoI indices\n",
"with open(\"assets/dkt_indices.txt\", mode=\"r\") as f:\n",
" gt_rois = {int(roi.strip()) for roi in f.readlines()}\n",
"\n",
"api = wandb.Api()\n",
"runs = api.runs(path=\"1pha/brain-age\",\n",
" filters={\"config.dataloader.dataset._target_\": \"sage.data.mask.UKB_MaskDataset\"})\n",
"len(runs)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"runs_dict = defaultdict(dict)\n",
"runs_idx = defaultdict(set)\n",
"for run in runs:\n",
" state = run.state\n",
" name: str = run.name.split(\" \")[1]\n",
" if not name.isnumeric():\n",
" # Skip outlier run name that does not follow `M XXXX | SEED`\n",
" continue\n",
" idx = int(name)\n",
" if state == \"finished\":\n",
" runs_dict[\"success\"][idx] = run\n",
" runs_idx[\"success\"].add(idx)\n",
" elif state == \"failed\":\n",
" runs_dict[\"fail\"][idx] = run\n",
" runs_idx[\"fail\"].add(idx)\n",
" else:\n",
" runs_dict[\"etc\"][idx] = (run, state)\n",
" runs_idx[\"etc\"].add(idx)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(66, 11, 3)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(runs_idx[\"success\"]), len(runs_idx[\"fail\"]), len(runs_idx[\"etc\"])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"assert gt_rois.issuperset(runs_idx[\"success\"])\n",
"leftover = gt_rois - runs_idx[\"success\"]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"34"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(leftover)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Actual failed runs: 9\n",
"1025\n",
"2018\n",
"2024\n",
"1002\n",
"77\n",
"2030\n",
"1009\n",
"54\n",
"24\n"
]
}
],
"source": [
"true_fail = {idx for idx in runs_idx[\"fail\"] if idx not in runs_idx[\"success\"]}\n",
"print(f\"Actual failed runs: {len(true_fail)}\")\n",
"for fail in true_fail:\n",
" print(fail)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Runs checked manually that has final test performance \n",
"manual_check_success: set = {1025, 2018, 2024, 1002, 77, 54, 24}\n",
"\n",
"# Was not able to get test performance but has final checkpoint\n",
"manual_check_ckpt: set = {2030, 1009}\n",
"\n",
"# Sanity check for empty runs\n",
"assert len(true_fail - (manual_check_ckpt | manual_check_success)) == 0\n",
"assert len((manual_check_ckpt | manual_check_success) - true_fail) == 0"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"25"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"leftover = leftover - true_fail\n",
"len(leftover)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{2017}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"leftover & runs_idx[\"etc\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Mask idx 2017 turned out to be crashed. Crahsed runs should be re-batched."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{2,\n",
" 46,\n",
" 47,\n",
" 49,\n",
" 50,\n",
" 251,\n",
" 252,\n",
" 253,\n",
" 254,\n",
" 255,\n",
" 1010,\n",
" 1011,\n",
" 1012,\n",
" 1013,\n",
" 1014,\n",
" 1015,\n",
" 1016,\n",
" 2005,\n",
" 2006,\n",
" 2007,\n",
" 2008,\n",
" 2017,\n",
" 2031,\n",
" 2034,\n",
" 2035}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"leftover"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Split runs and re-batch"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"def split_into_chunks(data_set, num_chunks):\n",
" # Convert set to list for shuffling\n",
" data_list = list(data_set)\n",
" \n",
" # Calculate the size of each chunk\n",
" chunk_size = len(data_list) // num_chunks\n",
" remainder = len(data_list) % num_chunks\n",
" \n",
" # Split the list into chunks\n",
" chunks = []\n",
" start = 0\n",
" for i in range(num_chunks):\n",
" chunk_length = chunk_size + (1 if i < remainder else 0)\n",
" chunks.append(data_list[start:start + chunk_length])\n",
" start += chunk_length\n",
" \n",
" flattened_chunks = [item for sublist in chunks for item in sublist]\n",
" assert set(flattened_chunks) == data_set, \"Chunks do not contain all elements of the data set\"\n",
" return chunks\n",
"\n",
"\n",
"machines = [\"185-0\", \"185-1\", \"245-0\", \"245-1\"]\n",
"chunks = split_into_chunks(data_set=leftover, num_chunks=4)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[2, 46, 47, 49, 50, 2005, 2006],\n",
" [2007, 2008, 2017, 2031, 2034, 2035],\n",
" [1010, 1011, 1012, 1013, 1016, 1014],\n",
" [1015, 251, 252, 253, 254, 255]]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chunks"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"for chunk, machine in zip(chunks, machines):\n",
" with open(f\"assets/dkt_leftover_{machine}.txt\", mode=\"w\") as f:\n",
" for roi in chunk:\n",
" f.write(f\"{roi}\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "age",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit b5cf589

Please sign in to comment.