Skip to content

Commit

Permalink
final compute node specific edits
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Aug 23, 2024
1 parent 3b7a106 commit b12c749
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions solution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# %% [markdown]
# # Exercise 05: Instance Segmentation
# # Exercise 05: Instance Segmentation :)
#
# So far, we were only interested in `semantic` classes, e.g. foreground / background etc.
# But in many cases we not only want to know if a certain pixel belongs to a specific class, but also to which unique object (i.e. the task of `instance segmentation`).
Expand Down Expand Up @@ -40,12 +40,11 @@
import numpy as np
import os
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.transforms import v2
from scipy.ndimage import distance_transform_edt, map_coordinates
from local import train, NucleiDataset, plot_two, plot_three, plot_four
from local import train, plot_two, plot_three, plot_four
from dlmbl_unet import UNet
from tqdm import tqdm
import tifffile
Expand All @@ -58,11 +57,11 @@
# %%
# Set some variables that are specific to the hardware being run on
# this should be optimized for the compute nodes once available.
device = "cpu" # 'cuda', 'cpu', 'mps'
NUM_THREADS = 0
NUM_EPOCHS = 20
device = "cuda" # 'cuda', 'cpu', 'mps'
NUM_THREADS = 8
NUM_EPOCHS = 200
# make sure gpu is available. Please call a TA if this cell fails
# assert torch.cuda.is_available()
assert torch.cuda.is_available()

# %%
# Create a custom label color map for showing instances
Expand Down Expand Up @@ -228,7 +227,7 @@ def compute_sdt(labels: np.ndarray, scale: int = 5):
# Modify the `SDTDataset` class below to produce the paired raw and SDT images.<br>
# 1. Fill in the `create_sdt_target` method to return an SDT output rather than a label mask.<br>
# - Ensure that all final outputs are of torch tensor type, and are converted to float.
# 2. Instantiate the dataset with a RandomCrop of size 256 and visualize the output to confirm that the SDT is correct.
# 2. Instantiate the dataset with a RandomCrop of size 128 and visualize the output to confirm that the SDT is correct.
# </div>


Expand Down Expand Up @@ -322,7 +321,7 @@ def __init__(self, root_dir, transform=None, img_transform=None, return_mask=Fal

self.loaded_imgs = [None] * self.num_samples
self.loaded_masks = [None] * self.num_samples
for sample_ind in tqdm(range(self.num_samples), desc="Reaqding Images"):
for sample_ind in tqdm(range(self.num_samples), desc="Reading Images"):
img_path = os.path.join(self.root_dir, f"img_{sample_ind}.tif")
image = self.from_np(tifffile.imread(img_path))
self.loaded_imgs[sample_ind] = inp_transforms(image)
Expand Down Expand Up @@ -366,7 +365,7 @@ def create_sdt_target(self, mask):


# %% tags=["task"]
# Create a dataset using a RandomCrop of size 256 (see torchvision.transforms.v2 imported as v2)
# Create a dataset using a RandomCrop of size 128 (see torchvision.transforms.v2 imported as v2)
# documentation here: https://pytorch.org/vision/stable/transforms.html#v2-api-reference-recommended
# Visualize the output to confirm your dataset is working.

Expand All @@ -378,11 +377,11 @@ def create_sdt_target(self, mask):
plot_two(img, sdt[0], label="SDT")

# %% tags=["solution"]
# Create a dataset using a RandomCrop of size 256 (see torchvision.transforms.v2 imported as v2)
# Create a dataset using a RandomCrop of size 128 (see torchvision.transforms.v2 imported as v2)
# documentation here: https://pytorch.org/vision/stable/transforms.html#v2-api-reference-recommended
# Visualize the output to confirm your dataset is working.

train_data = SDTDataset("tissuenet_data/train", v2.RandomCrop(256))
train_data = SDTDataset("tissuenet_data/train", v2.RandomCrop(128))
img, sdt = train_data[10] # get the image and the distance transform
# We use the `plot_two` function (imported in the first cell) to verify that our
# dataset solution is correct. The output should show 2 images: the raw image and
Expand Down Expand Up @@ -422,7 +421,7 @@ def create_sdt_target(self, mask):
# %%
# TODO: You don't have to add extra augmentations, training will work without.
# But feel free to experiment here if you want to come back and try to get better results if you have time.
train_data = SDTDataset("tissuenet_data/train", v2.RandomCrop(256))
train_data = SDTDataset("tissuenet_data/train", v2.RandomCrop(128))
train_loader = DataLoader(
train_data, batch_size=5, shuffle=True, num_workers=NUM_THREADS
)
Expand Down Expand Up @@ -1007,7 +1006,7 @@ def create_aff_target(self, mask):
neighborhood = [[0, 1], [1, 0], [0, 5], [5, 0]]
train_data = AffinityDataset(
"tissuenet_data/train",
v2.RandomCrop(256),
v2.RandomCrop(128),
weights=True,
neighborhood=neighborhood,
)
Expand Down Expand Up @@ -1089,7 +1088,7 @@ def create_aff_target(self, mask):
# It can also be useful to bias long range affinities more negatively than the short range affinities. The intuition here being that boundaries are often blurry in biology. This means it may not be easy to tell if the neighboring pixel has crossed a boundary, but it is reasonably easy to tell if there is a boundary accross a 5 pixel gap. Similarly, identifying if two pixels belong to the same object is easier, the closer they are to each other. Providing a more negative bias to long range affinities means we bias towards splitting on low long range affinities, and merging on high short range affinities.

# %%
val_data = AffinityDataset("tissuenet_data/test", v2.RandomCrop(256), return_mask=True)
val_data = AffinityDataset("tissuenet_data/test", v2.RandomCrop(128), return_mask=True)
val_loader = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=8)

unet.eval()
Expand Down Expand Up @@ -1222,7 +1221,7 @@ def create_aff_target(self, mask):
# %% tags=["solution"]
from cellpose import models

model = models.Cellpose(model_type="nuclei")
model = models.Cellpose(model_type="cyto3")
channels = [[0, 0]]

precision_list, recall_list, accuracy_list = [], [], []
Expand Down

0 comments on commit b12c749

Please sign in to comment.