Skip to content

Commit

Permalink
always use cyto masks and plot membrane channel
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Aug 19, 2024
1 parent 9c01234 commit 4d50c5a
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,10 @@ def compute_sdt(labels: np.ndarray, scale: int = 5):
idx = np.random.randint(len(samples) // 3) # take a random sample.
img = tifffile.imread(os.path.join(root_dir, f"img_{idx}.tif")) # get the image
label = tifffile.imread(
os.path.join(root_dir, f"img_{idx}_nuclei_masks.tif")
os.path.join(root_dir, f"img_{idx}_cyto_masks.tif")
) # get the image
sdt = compute_sdt(label)
plot_two(img[0], sdt, label="SDT")
plot_two(img[1], sdt, label="SDT")

# %% [markdown]
# <div class="alert alert-block alert-info">
Expand Down Expand Up @@ -232,9 +232,7 @@ def __init__(self, root_dir, transform=None, img_transform=None, return_mask=Fal
img_path = os.path.join(self.root_dir, f"img_{sample_ind}.tif")
image = self.to_img(tifffile.imread(img_path))
self.loaded_imgs[sample_ind] = inp_transforms(image)
mask_path = os.path.join(
self.root_dir, f"img_{sample_ind}_nuclei_masks.tif"
)
mask_path = os.path.join(self.root_dir, f"img_{sample_ind}_cyto_masks.tif")
mask = self.to_img(tifffile.imread(mask_path))
self.loaded_masks[sample_ind] = mask

Expand Down Expand Up @@ -293,9 +291,7 @@ def __init__(self, root_dir, transform=None, img_transform=None, return_mask=Fal
img_path = os.path.join(self.root_dir, f"img_{sample_ind}.tif")
image = self.to_img(tifffile.imread(img_path))
self.loaded_imgs[sample_ind] = inp_transforms(image)
mask_path = os.path.join(
self.root_dir, f"img_{sample_ind}_nuclei_masks.tif"
)
mask_path = os.path.join(self.root_dir, f"img_{sample_ind}_cyto_masks.tif")
mask = self.to_img(tifffile.imread(mask_path))
self.loaded_masks[sample_ind] = mask

Expand Down Expand Up @@ -422,7 +418,7 @@ def create_sdt_target(self, mask):
image = np.squeeze(image.cpu())
sdt = np.squeeze(sdt.cpu().numpy())
pred = np.squeeze(pred.cpu().detach().numpy())
plot_three(image[0], sdt, pred)
plot_three(image[1], sdt, pred)


# %% [markdown]
Expand Down Expand Up @@ -577,7 +573,7 @@ def get_inner_mask(pred, threshold):
# %%
# Visualize the results

plot_four(image[0], mask, pred, seg, label="Target", cmap=label_cmap)
plot_four(image[1], mask, pred, seg, label="Target", cmap=label_cmap)

# %% [markdown]
# <div class="alert alert-block alert-info">
Expand Down Expand Up @@ -770,7 +766,7 @@ def create_aff_target(self, mask):
)
idx = np.random.randint(len(train_data)) # take a random sample
img, affinity = train_data[idx] # get the image and the nuclei masks
plot_two(img[0], affinity[0] + affinity[1], label="AFFINITY")
plot_two(img[1], affinity[0+2] + affinity[1+2], label="AFFINITY")

# %% [markdown]
# <div class="alert alert-block alert-info">
Expand Down Expand Up @@ -809,6 +805,7 @@ def create_aff_target(self, mask):
loss = torch.nn.MSELoss()

optimizer = torch.optim.Adam(unet.parameters(), lr=learning_rate)
plot_three(image[1], mask[0] + mask[1], pred[0 + 2] + pred[1 + 2], label="Affinity")

for epoch in range(NUM_EPOCHS):
train(
Expand Down Expand Up @@ -838,7 +835,7 @@ def create_aff_target(self, mask):
mask = mask.cpu().numpy()
pred = pred.cpu().detach().numpy()

plot_three(image[0], mask[0] + mask[1], pred[0] + pred[1], label="Affinity")
plot_three(image[1], mask[0] + mask[1], pred[0] + pred[1], label="Affinity")

# %% [markdown]
# Let's also evaluate the model performance.
Expand Down

0 comments on commit 4d50c5a

Please sign in to comment.