Commit

Cleaned up code
Arfan12630 committed Dec 20, 2024
1 parent a81ccba commit 8a0a184
Showing 1 changed file with 1 addition and 223 deletions.
@@ -128,17 +128,6 @@ def __init__(self):
self.half = False



self.subscription = self.create_subscription(
Image if not self.compressed else CompressedImage,
self.camera_topic,
self.image_callback,
qos_profile=QoSProfile(
reliability=QoSReliabilityPolicy.RELIABLE,
history=QoSHistoryPolicy.KEEP_LAST,
depth=10,
),
)
#Subscription for Nuscenes
self.front_center_camera_subscription = Subscriber(self, CompressedImage, self.front_center_camera_topic)
self.front_right_camera_subscription = Subscriber(self, CompressedImage, self.front_right_camera_topic)
@@ -181,9 +170,7 @@ def __init__(self):
#self.build_engine()
self.last_publish_time = self.get_clock().now()

self.model = AutoBackend(self.model_path, device=self.device, dnn=False, fp16=False)
self.names = self.model.module.names if hasattr(self.model, "module") else self.model.names
self.stride = int(self.model.stride)


#Batch vis publishers
self.batched_camera_message_publisher = self.create_publisher(BatchDetection,self.batch_inference_topic, 10)
@@ -598,8 +585,6 @@ def publish_batch_eve(self,image_list, decoded_results):
detection_array.detections.append(detection)
batch_msg.detections.append(detection_array)
annotated_image = annotator.result()


vis_image = Image()
vis_image.header.stamp = self.get_clock().now().to_msg()
vis_image.header.frame_id = f"camera_{idx}"
@@ -612,213 +597,6 @@ def publish_batch_eve(self,image_list, decoded_results):

# Publish batch detection message
self.batched_camera_message_publisher.publish(batch_msg)



def crop_image(self, cv_image):
if self.crop_mode == "LetterBox":
img = LetterBox(self.image_size, stride=self.stride)(image=cv_image)
elif self.crop_mode == "CenterCrop":
img = CenterCrop(self.image_size)(cv_image)
else:
raise Exception("Invalid crop mode, please choose either 'LetterBox' or 'CenterCrop'!")

return img

def convert_bboxes_to_orig_frame(self, bbox):
"""
Converts bounding box coordinates from the scaled image frame back to the original image frame.
This function takes into account the original image dimensions and the scaling method used
(either "LetterBox" or "CenterCrop") to accurately map the bounding box coordinates back to
their original positions in the original image.
Parameters:
bbox (list): A list containing the bounding box coordinates in the format [x1, y1, w1, h1]
in the scaled image frame.
Returns:
list: A list containing the bounding box coordinates in the format [x1, y1, w1, h1]
in the original image frame.
"""
width_scale = self.orig_image_width / self.image_size
height_scale = self.orig_image_height / self.image_size
if self.crop_mode == "LetterBox":
translation = (self.image_size - self.orig_image_height / width_scale) / 2
return [
bbox[0] * width_scale,
(bbox[1] - translation) * width_scale,
bbox[2] * width_scale,
bbox[3] * width_scale,
]
elif self.crop_mode == "CenterCrop":
translation = (self.orig_image_width / height_scale - self.image_size) / 2
return [
(bbox[0] + translation) * height_scale,
bbox[1] * height_scale,
bbox[2] * height_scale,
bbox[3] * height_scale,
]
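
# A worked sanity check for the LetterBox branch of convert_bboxes_to_orig_frame.
# The 1280x720 source resolution and image_size of 640 below are illustrative
# assumptions, not values taken from this node.
image_size = 640
orig_w, orig_h = 1280, 720
width_scale = orig_w / image_size                        # 2.0
translation = (image_size - orig_h / width_scale) / 2    # (640 - 360) / 2 = 140
bbox = [100.0, 200.0, 50.0, 40.0]                        # [x, y, w, h] in the scaled frame
orig_bbox = [
    bbox[0] * width_scale,                  # 200.0
    (bbox[1] - translation) * width_scale,  # 120.0
    bbox[2] * width_scale,                  # 100.0
    bbox[3] * width_scale,                  # 80.0
]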

def crop_and_convert_to_tensor(self, cv_image):
"""
Preprocess the image by resizing, padding and rearranging the dimensions.
Parameters:
cv_image: A numpy/OpenCV image of shape (h, w, 3)
Returns:
torch.Tensor image for model input of shape (1, 3, h, w)
"""
img = self.crop_image(cv_image)

# Rearrange from HWC to CHW for the model (use the cropped image, not the raw input)
img = img.transpose(2, 0, 1)

# Move to the target device, cast, and normalize to [0, 1]
img = torch.from_numpy(img).to(self.device)
img = img.half() if self.half else img.float() # uint8 to fp16/32
img /= 255.0 # 0 - 255 to 0.0 - 1.0
img = img.unsqueeze(0)

return img
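
# A minimal standalone sketch of the same preprocessing contract, assuming an
# image_size of 640, CPU inference, and that LetterBox is importable from
# ultralytics.data.augment (an assumption; the import path has moved between
# ultralytics releases).
import numpy as np
import torch
from ultralytics.data.augment import LetterBox

def preprocess(cv_image: np.ndarray, image_size: int = 640) -> torch.Tensor:
    img = LetterBox(image_size)(image=cv_image)  # resize + pad, keeping aspect ratio
    img = img.transpose(2, 0, 1)                 # HWC -> CHW
    img = torch.from_numpy(np.ascontiguousarray(img)).float() / 255.0
    return img.unsqueeze(0)                      # shape (1, 3, H, W)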

def postprocess_detections(self, detections, annotator):
"""
Post-processes detections and draws bounding boxes on the camera image.
Parameters:
detections: A list of dict with the format
{
"label": str,
"bbox": [float],
"conf": float
}
annotator: An ultralytics.yolo.utils.plotting.Annotator for the current image
Returns:
processed_detections: filtered detections
annotator_img: image with bounding boxes drawn on
"""
processed_detections = detections

for det in detections:
label = f'{det["label"]} {det["conf"]:.2f}'
x1, y1, w1, h1 = det["bbox"]
xyxy = [x1, y1, x1 + w1, y1 + h1]
annotator.box_label(xyxy, label, color=colors(1, True))

annotator_img = annotator.result()
return (processed_detections, annotator_img)
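
# Hedged usage sketch for postprocess_detections. "frame" (a BGR numpy image)
# and "node" (an instance of this class) are assumed names; the detection dict
# values are illustrative.
annotator = Annotator(frame, line_width=2, example="person")
dets = [{"label": "person", "conf": 0.87, "bbox": [50.0, 60.0, 120.0, 200.0]}]
dets, annotated_img = node.postprocess_detections(dets, annotator)  # draws a box from (50, 60) to (170, 260)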

def publish_vis(self, annotated_img, msg, feed):
# Publish visualizations
imgmsg = self.cv_bridge.cv2_to_imgmsg(annotated_img, "bgr8")
imgmsg.header.stamp = msg.header.stamp
imgmsg.header.frame_id = msg.header.frame_id
self.vis_publisher.publish(imgmsg)

def publish_detections(self, detections, msg, feed):
# Publish detections as a Detection2DArray message
detection2darray = Detection2DArray()

# fill header for detection list
detection2darray.header.stamp = msg.header.stamp
detection2darray.header.frame_id = msg.header.frame_id
# populate detection list
if detections is not None and len(detections):
for detection in detections:
detection2d = Detection2D()
detection2d.header.stamp = msg.header.stamp
detection2d.header.frame_id = msg.header.frame_id
detected_object = ObjectHypothesisWithPose()
detected_object.hypothesis.class_id = detection["label"]
detected_object.hypothesis.score = detection["conf"]
detection2d.results.append(detected_object)
detection2d.bbox.center.position.x = detection["bbox"][0]
detection2d.bbox.center.position.y = detection["bbox"][1]
detection2d.bbox.size_x = detection["bbox"][2]
detection2d.bbox.size_y = detection["bbox"][3]

# append detection to detection list
detection2darray.detections.append(detection2d)

self.detection_publisher.publish(detection2darray)

def image_callback(self, msg):
self.get_logger().debug("Received image")
if self.orig_image_width is None:
self.orig_image_width = msg.width
self.orig_image_height = msg.height

images = [msg] # msg is a single sensor image
startTime = time.time()
for image in images:

# convert ros Image to cv::Mat
if self.compressed:
np_arr = np.frombuffer(msg.data, np.uint8)
cv_image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
else:
try:
cv_image = self.cv_bridge.imgmsg_to_cv2(image, desired_encoding="passthrough")
except CvBridgeError as e:
self.get_logger().error(str(e))
return

# preprocess image and run through prediction
img = self.crop_and_convert_to_tensor(cv_image)
pred = self.model(img)

# Non-max suppression, same approach as YOLOv8 detect.py (eliminates overlapping bounding boxes)
pred = non_max_suppression(pred)
detections = []
for i, det in enumerate(pred): # per image
if len(det):
# Write results
for *xyxy, conf, cls in reversed(det):
label = self.names[int(cls)]

bbox = [
xyxy[0],
xyxy[1],
xyxy[2] - xyxy[0],
xyxy[3] - xyxy[1],
]
bbox = [b.item() for b in bbox]
bbox = self.convert_bboxes_to_orig_frame(bbox)

detections.append(
{
"label": label,
"conf": conf.item(),
"bbox": bbox,
}
)
self.get_logger().debug(f"{label}: {bbox}")

annotator = Annotator(
cv_image,
line_width=self.line_thickness,
example=str(self.names),
)
(detections, annotated_img) = self.postprocess_detections(detections, annotator)

# Currently we support a single camera so we pass an empty string
feed = ""
self.publish_vis(annotated_img, msg, feed)
self.publish_detections(detections, msg, feed)

if self.save_detections:
cv2.imwrite(f"detections/{self.counter}.jpg", annotated_img)
self.counter += 1

self.get_logger().info(
f"Finished in: {time.time() - startTime}, {1/(time.time() - startTime)} Hz"
)


def main(args=None):
rclpy.init(args=args)

