diff --git a/solution.py b/solution.py
index a4e305f..cb05ea2 100644
--- a/solution.py
+++ b/solution.py
@@ -625,9 +625,10 @@ def get_inner_mask(pred, threshold):
# %% [markdown]
#
-# Task 3.2: Evaluate metrics for the validation dataset.
+# Task 3.2: Evaluate metrics for the validation dataset. Fill in the blanks.
#
-# %%
+
+# %% tags=["task"]
from local import evaluate
# Need to re-initialize the dataloader to return masks in addition to SDTs.
@@ -671,6 +672,49 @@ def get_inner_mask(pred, threshold):
print(f"Mean Precision is {np.mean(precision_list):.3f}")
print(f"Mean Recall is {np.mean(recall_list):.3f}")
print(f"Mean Accuracy is {np.mean(accuracy_list):.3f}")
+
+# %% tags=["solution"]
+from local import evaluate
+# Solution cell for Task 3.2: score every sample from val_dataloader and print the mean precision, recall and accuracy.
+# Need to re-initialize the dataset and dataloader with return_mask=True so that masks are returned in addition to SDTs; the loop below unpacks (image, mask, sdt).
+val_dataset = SDTDataset("tissuenet_data/test", return_mask=True)  # NOTE(review): variable says "val" but the path ends in /test - confirm the intended split
+val_dataloader = DataLoader(
+ val_dataset, batch_size=1, shuffle=False, num_workers=NUM_THREADS
+)
+unet.eval()  # evaluation mode for inference (assumes unet is a torch.nn.Module - confirm)
+# Per-sample scores; one entry is appended per loop iteration and the lists are averaged after the loop.
+(
+ precision_list,
+ recall_list,
+ accuracy_list,
+) = (
+ [],
+ [],
+ [],
+)
+for idx, (image, mask, sdt) in enumerate(tqdm(val_dataloader)):
+ image = image.to(device)
+ pred = unet(image)
+ # NOTE(review): no torch.no_grad() around the forward pass; pred is detached below before conversion to NumPy.
+ image = np.squeeze(image.cpu())  # NOTE(review): image is not used again in this cell
+ gt_labels = np.squeeze(mask.cpu().numpy())
+ pred = np.squeeze(pred.cpu().detach().numpy())
+ # NOTE(review): thresh, inner_mask and pred_labels below are still "..." placeholders although this cell is tagged "solution"; as written evaluate() receives Ellipsis - fill in or confirm intent.
+ # feel free to try different thresholds
+ thresh = ...
+
+ # get inner mask (original comment said "boundary mask"; presumably computed with get_inner_mask(pred, threshold) - confirm)
+ inner_mask = ...
+ pred_labels = ...
+ precision, recall, accuracy = evaluate(gt_labels, pred_labels)
+ precision_list.append(precision)
+ recall_list.append(recall)
+ accuracy_list.append(accuracy)
+
+print(f"Mean Precision is {np.mean(precision_list):.3f}")
+print(f"Mean Recall is {np.mean(recall_list):.3f}")
+print(f"Mean Accuracy is {np.mean(accuracy_list):.3f}")
+
# %% [markdown]
#
#