Skip to content

Commit

Permalink
fix: Improve prompt handling and rectangle shape prediction in Canvas
Browse files Browse the repository at this point in the history
- Added check to ensure prompt is not None before processing
- Refactored logic to handle cases where rectangle_shapes is None
- Initialized AI model only when rectangle_shapes is None
  • Loading branch information
healthonrails committed Jan 10, 2025
1 parent 753ef7d commit 4636eae
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions annolid/gui/widgets/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def predictAiRectangle(self, prompt,
self.sam_hq_model = SamHQSegmenter()

# If the prompt contains 'every', segment everything
if 'every' in prompt.lower():
if prompt and 'every' in prompt.lower():
points_per_side, prompt = extract_number_and_remove_digits(prompt)
if points_per_side < 1 or points_per_side > 100:
points_per_side = 32
Expand All @@ -300,21 +300,22 @@ def predictAiRectangle(self, prompt,
is_polygon_output=is_polygon_output)
return

if rectangle_shapes is None:
rectangle_shapes = []
_bboxes = self._predict_similar_rectangles(
rectangle_shapes=rectangle_shapes, prompt=prompt)

label = prompt

# Initialize AI model if not already initialized
if self._ai_model_rect is None:
self._ai_model_rect = GroundingDINO()

# # Predict bounding boxes using the AI model
bboxes = self._ai_model_rect.predict_bboxes(image_data, prompt)
gd_bboxes = [list(box) for box, _ in bboxes]
_bboxes.extend(gd_bboxes)
if rectangle_shapes is not None:
_bboxes = self._predict_similar_rectangles(
rectangle_shapes=rectangle_shapes, prompt=prompt)
else:
rectangle_shapes = []
_bboxes = self._predict_similar_rectangles(
rectangle_shapes=rectangle_shapes, prompt=prompt)
# Initialize AI model if not already initialized
if self._ai_model_rect is None:
self._ai_model_rect = GroundingDINO()

# # Predict bounding boxes using the AI model
bboxes = self._ai_model_rect.predict_bboxes(image_data, prompt)
gd_bboxes = [list(box) for box, _ in bboxes]
_bboxes.extend(gd_bboxes)

# Segment objects using SAM HQ model with predicted bounding boxes
masks, scores, _bboxes = self.sam_hq_model.segment_objects(
Expand Down

0 comments on commit 4636eae

Please sign in to comment.