From a6a06c0cb5022ef842fb2affe7d6077b6524ade0 Mon Sep 17 00:00:00 2001 From: healthonrails Date: Wed, 10 Jan 2024 13:29:57 -0500 Subject: [PATCH] Enhance predictions by incorporating center points from neighboring polygons and considering points outside the current polygon as negative prompt points. --- annolid/segmentation/SAM/edge_sam_bg.py | 41 +++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/annolid/segmentation/SAM/edge_sam_bg.py b/annolid/segmentation/SAM/edge_sam_bg.py index 38fdff42..1e7f6384 100644 --- a/annolid/segmentation/SAM/edge_sam_bg.py +++ b/annolid/segmentation/SAM/edge_sam_bg.py @@ -72,6 +72,24 @@ def random_sample_inside_edges(polygon, num_points): return np.array(sampled_points) +def random_sample_outside_edges(polygon, num_points): + # Randomly sample points inside the edges of the polygon + sampled_points = [] + min_x, min_y, max_x, max_y = polygon.bounds + + for _ in range(num_points): + # Generate random point inside the bounding box + x = np.random.uniform(min_x, max_x) + y = np.random.uniform(min_y, max_y) + point = Point(x, y) + + # Check if the point is inside the polygon + if not point.within(polygon): + sampled_points.append((x, y)) + + return np.array(sampled_points) + + def find_bbox(polygon_points): # Convert the list of polygon points to a NumPy array points_array = np.array(polygon_points) @@ -233,14 +251,24 @@ def process_frame(self, frame_number): # Randomly sample points inside the edges of the polygon points_inside_edges = random_sample_inside_edges(polygon, self.num_points_inside_edges) + points_outside_edges = random_sample_outside_edges(polygon, + self.num_points_inside_edges * 3 + ) points_uni = uniform_points_inside_polygon( polygon, self.num_points_inside_edges) center_points = self.center_points_dict.get(label, MaxSizeQueue(max_size=self.num_center_points)) + center_points.enqueue(points[0]) points = center_points.to_numpy() self.center_points_dict[label] = center_points + # use other instance's center points as negative point prompts + other_polygon_center_points = [ + value for k, v in self.center_points_dict.items() if k != label for value in v] + other_polygon_center_points = np.array( + [(x[0], x[1]) for x in other_polygon_center_points]) + if len(points_inside_edges.shape) > 1: points = np.concatenate( (points, points_inside_edges), axis=0) @@ -250,6 +278,19 @@ def process_frame(self, frame_number): ) point_labels = [1] * len(points) + if len(points_outside_edges) > 1: + points = np.concatenate( + (points, points_outside_edges), axis=0 + ) + point_labels += [0] * len(points_outside_edges) + + if len(other_polygon_center_points) > 1: + points = np.concatenate( + (points, other_polygon_center_points), + axis=0 + ) + point_labels += [0] * len(other_polygon_center_points) + polygon = self.edge_sam.predict_polygon_from_points( points, point_labels)