Skip to content

Commit

Permalink
Enhance predictions by incorporating center points from neighboring p…
Browse files Browse the repository at this point in the history
…olygons and considering points outside the current polygon as negative prompt points.
  • Loading branch information
healthonrails committed Jan 10, 2024
1 parent fa98aa8 commit a6a06c0
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions annolid/segmentation/SAM/edge_sam_bg.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,24 @@ def random_sample_inside_edges(polygon, num_points):
return np.array(sampled_points)


def random_sample_outside_edges(polygon, num_points):
# Randomly sample points inside the edges of the polygon
sampled_points = []
min_x, min_y, max_x, max_y = polygon.bounds

for _ in range(num_points):
# Generate random point inside the bounding box
x = np.random.uniform(min_x, max_x)
y = np.random.uniform(min_y, max_y)
point = Point(x, y)

# Check if the point is inside the polygon
if not point.within(polygon):
sampled_points.append((x, y))

return np.array(sampled_points)


def find_bbox(polygon_points):
# Convert the list of polygon points to a NumPy array
points_array = np.array(polygon_points)
Expand Down Expand Up @@ -233,14 +251,24 @@ def process_frame(self, frame_number):
# Randomly sample points inside the edges of the polygon
points_inside_edges = random_sample_inside_edges(polygon,
self.num_points_inside_edges)
points_outside_edges = random_sample_outside_edges(polygon,
self.num_points_inside_edges * 3
)
points_uni = uniform_points_inside_polygon(
polygon, self.num_points_inside_edges)
center_points = self.center_points_dict.get(label,
MaxSizeQueue(max_size=self.num_center_points))

center_points.enqueue(points[0])
points = center_points.to_numpy()
self.center_points_dict[label] = center_points

# use other instance's center points as negative point prompts
other_polygon_center_points = [
value for k, v in self.center_points_dict.items() if k != label for value in v]
other_polygon_center_points = np.array(
[(x[0], x[1]) for x in other_polygon_center_points])

if len(points_inside_edges.shape) > 1:
points = np.concatenate(
(points, points_inside_edges), axis=0)
Expand All @@ -250,6 +278,19 @@ def process_frame(self, frame_number):
)

point_labels = [1] * len(points)
if len(points_outside_edges) > 1:
points = np.concatenate(
(points, points_outside_edges), axis=0
)
point_labels += [0] * len(points_outside_edges)

if len(other_polygon_center_points) > 1:
points = np.concatenate(
(points, other_polygon_center_points),
axis=0
)
point_labels += [0] * len(other_polygon_center_points)

polygon = self.edge_sam.predict_polygon_from_points(
points, point_labels)

Expand Down

0 comments on commit a6a06c0

Please sign in to comment.