From 2bfa33ad47c8d11ee2f425df6067c6727fb4487f Mon Sep 17 00:00:00 2001 From: nidefawl <637382+nidefawl@users.noreply.github.com> Date: Sat, 30 Dec 2023 00:58:34 +0100 Subject: [PATCH 1/2] Fix threshold in sam_predict --- modules/impact/core.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/modules/impact/core.py b/modules/impact/core.py index fd15dbdf..22d749aa 100644 --- a/modules/impact/core.py +++ b/modules/impact/core.py @@ -468,25 +468,10 @@ def sam_predict(predictor, points, plabs, bbox, threshold): cur_masks, scores, _ = predictor.predict(point_coords=point_coords, point_labels=point_labels, box=box) - total_masks = [] - - selected = False - max_score = 0 - for idx in range(len(scores)): - if scores[idx] > max_score: - max_score = scores[idx] - max_mask = cur_masks[idx] - - if scores[idx] >= threshold: - selected = True - total_masks.append(cur_masks[idx]) - else: - pass - - if not selected: - total_masks.append(max_mask) - - return total_masks + # take all 3 masks predict returns, or take none + if any([score >= threshold for score in scores]): + return [m for m in cur_masks] + return [] def make_sam_mask(sam_model, segs, image, detection_hint, dilation, From 60cd7be85acfa4a04e52450c68296acc6e06ea22 Mon Sep 17 00:00:00 2001 From: nidefawl <637382+nidefawl@users.noreply.github.com> Date: Sat, 30 Dec 2023 03:03:53 +0100 Subject: [PATCH 2/2] Fix size of empty mask --- modules/impact/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/impact/core.py b/modules/impact/core.py index 22d749aa..7e53a586 100644 --- a/modules/impact/core.py +++ b/modules/impact/core.py @@ -591,7 +591,7 @@ def make_sam_mask(sam_model, segs, image, detection_hint, dilation, mask = dilate_mask(mask.cpu().numpy(), dilation) mask = torch.from_numpy(mask) else: - mask = torch.zeros((8, 8), dtype=torch.float32, device="cpu") # empty mask + mask = torch.zeros((image.shape[0], image.shape[1]), dtype=torch.float32, device="cpu") # empty mask mask = utils.make_3d_mask(mask) return mask