From 4d50c5af34bbc5a33e3ed7ba39755e9e8ed07895 Mon Sep 17 00:00:00 2001 From: William Patton Date: Mon, 19 Aug 2024 07:35:01 -0700 Subject: [PATCH] always use cyto masks and plot membrane channel --- solution.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/solution.py b/solution.py index 61f8855..a4e305f 100644 --- a/solution.py +++ b/solution.py @@ -171,10 +171,10 @@ def compute_sdt(labels: np.ndarray, scale: int = 5): idx = np.random.randint(len(samples) // 3) # take a random sample. img = tifffile.imread(os.path.join(root_dir, f"img_{idx}.tif")) # get the image label = tifffile.imread( - os.path.join(root_dir, f"img_{idx}_nuclei_masks.tif") + os.path.join(root_dir, f"img_{idx}_cyto_masks.tif") ) # get the image sdt = compute_sdt(label) -plot_two(img[0], sdt, label="SDT") +plot_two(img[1], sdt, label="SDT") # %% [markdown] #
@@ -232,9 +232,7 @@ def __init__(self, root_dir, transform=None, img_transform=None, return_mask=Fal img_path = os.path.join(self.root_dir, f"img_{sample_ind}.tif") image = self.to_img(tifffile.imread(img_path)) self.loaded_imgs[sample_ind] = inp_transforms(image) - mask_path = os.path.join( - self.root_dir, f"img_{sample_ind}_nuclei_masks.tif" - ) + mask_path = os.path.join(self.root_dir, f"img_{sample_ind}_cyto_masks.tif") mask = self.to_img(tifffile.imread(mask_path)) self.loaded_masks[sample_ind] = mask @@ -293,9 +291,7 @@ def __init__(self, root_dir, transform=None, img_transform=None, return_mask=Fal img_path = os.path.join(self.root_dir, f"img_{sample_ind}.tif") image = self.to_img(tifffile.imread(img_path)) self.loaded_imgs[sample_ind] = inp_transforms(image) - mask_path = os.path.join( - self.root_dir, f"img_{sample_ind}_nuclei_masks.tif" - ) + mask_path = os.path.join(self.root_dir, f"img_{sample_ind}_cyto_masks.tif") mask = self.to_img(tifffile.imread(mask_path)) self.loaded_masks[sample_ind] = mask @@ -422,7 +418,7 @@ def create_sdt_target(self, mask): image = np.squeeze(image.cpu()) sdt = np.squeeze(sdt.cpu().numpy()) pred = np.squeeze(pred.cpu().detach().numpy()) -plot_three(image[0], sdt, pred) +plot_three(image[1], sdt, pred) # %% [markdown] @@ -577,7 +573,7 @@ def get_inner_mask(pred, threshold): # %% # Visualize the results -plot_four(image[0], mask, pred, seg, label="Target", cmap=label_cmap) +plot_four(image[1], mask, pred, seg, label="Target", cmap=label_cmap) # %% [markdown] #
@@ -770,7 +766,7 @@ def create_aff_target(self, mask): ) idx = np.random.randint(len(train_data)) # take a random sample img, affinity = train_data[idx] # get the image and the nuclei masks -plot_two(img[0], affinity[0] + affinity[1], label="AFFINITY") +plot_two(img[1], affinity[0+2] + affinity[1+2], label="AFFINITY") # %% [markdown] #
@@ -809,6 +805,7 @@ def create_aff_target(self, mask): loss = torch.nn.MSELoss() optimizer = torch.optim.Adam(unet.parameters(), lr=learning_rate) +plot_three(image[1], mask[0] + mask[1], pred[0 + 2] + pred[1 + 2], label="Affinity") for epoch in range(NUM_EPOCHS): train( @@ -838,7 +835,7 @@ def create_aff_target(self, mask): mask = mask.cpu().numpy() pred = pred.cpu().detach().numpy() -plot_three(image[0], mask[0] + mask[1], pred[0] + pred[1], label="Affinity") +plot_three(image[1], mask[0] + mask[1], pred[0] + pred[1], label="Affinity") # %% [markdown] # Let's also evaluate the model performance.