diff --git a/solution.py b/solution.py index a4e305f..cb05ea2 100644 --- a/solution.py +++ b/solution.py @@ -625,9 +625,10 @@ def get_inner_mask(pred, threshold): # %% [markdown] #
-# Task 3.2:
Evaluate metrics for the validation dataset. +# Task 3.2:
Evaluate metrics for the validation dataset. Fill in the blanks #
-# %% + +# %% tags=["task"] from local import evaluate # Need to re-initialize the dataloader to return masks in addition to SDTs. @@ -671,6 +672,49 @@ def get_inner_mask(pred, threshold): print(f"Mean Precision is {np.mean(precision_list):.3f}") print(f"Mean Recall is {np.mean(recall_list):.3f}") print(f"Mean Accuracy is {np.mean(accuracy_list):.3f}") + +# %% tags=["solution"] +from local import evaluate + +# Need to re-initialize the dataloader to return masks in addition to SDTs. +val_dataset = SDTDataset("tissuenet_data/test", return_mask=True) +val_dataloader = DataLoader( + val_dataset, batch_size=1, shuffle=False, num_workers=NUM_THREADS +) +unet.eval() + +( + precision_list, + recall_list, + accuracy_list, +) = ( + [], + [], + [], +) +for idx, (image, mask, sdt) in enumerate(tqdm(val_dataloader)): + image = image.to(device) + pred = unet(image) + + image = np.squeeze(image.cpu()) + gt_labels = np.squeeze(mask.cpu().numpy()) + pred = np.squeeze(pred.cpu().detach().numpy()) + + # feel free to try different thresholds + thresh = ... + + # get boundary mask + inner_mask = ... + pred_labels = ... + precision, recall, accuracy = evaluate(gt_labels, pred_labels) + precision_list.append(precision) + recall_list.append(recall) + accuracy_list.append(accuracy) + +print(f"Mean Precision is {np.mean(precision_list):.3f}") +print(f"Mean Recall is {np.mean(recall_list):.3f}") +print(f"Mean Accuracy is {np.mean(accuracy_list):.3f}") + # %% [markdown] #
#