Skip to content

Commit

Permalink
add task/solution split for task 3.2
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Aug 19, 2024
1 parent 4d50c5a commit 506d8be
Showing 1 changed file with 46 additions and 2 deletions.
48 changes: 46 additions & 2 deletions solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,9 +625,10 @@ def get_inner_mask(pred, threshold):

# %% [markdown]
# <div class="alert alert-block alert-info">
# <b>Task 3.2</b>: <br> Evaluate metrics for the validation dataset.
# <b>Task 3.2</b>: <br> Evaluate metrics for the validation dataset. Fill in the blanks
# </div>
# %%

# %% tags=["task"]
from local import evaluate

# Need to re-initialize the dataloader to return masks in addition to SDTs.
Expand Down Expand Up @@ -671,6 +672,49 @@ def get_inner_mask(pred, threshold):
print(f"Mean Precision is {np.mean(precision_list):.3f}")
print(f"Mean Recall is {np.mean(recall_list):.3f}")
print(f"Mean Accuracy is {np.mean(accuracy_list):.3f}")

# %% tags=["solution"]
from local import evaluate

# Need to re-initialize the dataloader to return masks in addition to SDTs.
val_dataset = SDTDataset("tissuenet_data/test", return_mask=True)
val_dataloader = DataLoader(
val_dataset, batch_size=1, shuffle=False, num_workers=NUM_THREADS
)
unet.eval()

(
precision_list,
recall_list,
accuracy_list,
) = (
[],
[],
[],
)
for idx, (image, mask, sdt) in enumerate(tqdm(val_dataloader)):
image = image.to(device)
pred = unet(image)

image = np.squeeze(image.cpu())
gt_labels = np.squeeze(mask.cpu().numpy())
pred = np.squeeze(pred.cpu().detach().numpy())

# feel free to try different thresholds
thresh = ...

# get boundary mask
inner_mask = ...
pred_labels = ...
precision, recall, accuracy = evaluate(gt_labels, pred_labels)
precision_list.append(precision)
recall_list.append(recall)
accuracy_list.append(accuracy)

print(f"Mean Precision is {np.mean(precision_list):.3f}")
print(f"Mean Recall is {np.mean(recall_list):.3f}")
print(f"Mean Accuracy is {np.mean(accuracy_list):.3f}")

# %% [markdown]
# <hr style="height:2px;">
#
Expand Down

0 comments on commit 506d8be

Please sign in to comment.