-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathtrackkkkk.py
153 lines (118 loc) · 5.31 KB
/
trackkkkk.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from random import randint

import cv2
import hydra
import numpy as np
import torch

# Local SORT implementation (provides ``Sort``). Kept ahead of the ultralytics
# imports so that, as before, their names take precedence over anything the
# star import might also define.
from sort import *
from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops
from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
# Module-level SORT tracker read by DetectionPredictor.write_results; set by init_tracker().
tracker = None


def init_tracker(max_age=5, min_hits=2, iou_threshold=0.2):
    """Create the module-level SORT tracker used by ``DetectionPredictor``.

    The defaults reproduce the previously hard-coded settings.

    Args:
        max_age: Forwarded to ``Sort(max_age=...)`` (presumably the number of
            frames an unmatched track is kept alive -- see the SORT docs).
        min_hits: Forwarded to ``Sort(min_hits=...)``.
        iou_threshold: Forwarded to ``Sort(iou_threshold=...)``.

    Returns:
        The new tracker, which is also stored in the global ``tracker``.
    """
    global tracker
    tracker = Sort(max_age=max_age, min_hits=min_hits, iou_threshold=iou_threshold)
    return tracker


# Palette indexed by track ID when drawing trails; filled by random_color_list().
rand_color_list = []
def draw_boxes(img, bbox, identities=None, categories=None, names=None, offset=(0, 0)):
    """Draw each box and its track-ID label onto ``img`` with OpenCV.

    Args:
        img: Image array passed to ``cv2.rectangle``/``cv2.putText``; drawn on in place.
        bbox: Iterable of 4-element boxes ``(x1, y1, x2, y2)``; values are truncated with ``int``.
        identities: Optional per-box IDs used as labels. When ``None``, every label is ``"0"``.
        categories: Currently unused; accepted for call-site compatibility.
        names: Currently unused; accepted for call-site compatibility.
        offset: ``offset[0]`` is added to the x coordinates and ``offset[1]`` to the y coordinates.

    Returns:
        The same ``img`` object.
    """
    for box_idx, box in enumerate(bbox):
        x1, y1, x2, y2 = [int(coord) for coord in box]
        x1 += offset[0]
        x2 += offset[0]
        y1 += offset[1]
        y2 += offset[1]

        track_id = int(identities[box_idx]) if identities is not None else 0
        label = str(track_id)
        # Only the text width is needed, to size the filled label background.
        (text_w, _text_h), _baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)

        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 253), 2)
        # Filled strip 20 px tall directly above the box, behind the label text.
        cv2.rectangle(img, (x1, y1 - 20), (x1 + text_w, y1), (255, 144, 30), -1)
        cv2.putText(img, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, [255, 255, 255], 1)
    return img
def random_color_list(count=5005):
    """Regenerate the module-level ``rand_color_list`` palette.

    Args:
        count: Number of colors to generate. The default (5005) matches the
            previously hard-coded palette size.

    Returns:
        The new palette: a list of ``count`` 3-tuples of ints in ``[0, 255]``.
        The same list is also stored in the global ``rand_color_list``.
    """
    global rand_color_list
    # Tuple elements are evaluated left to right, so randint is called in the
    # same order as before (first, second, third component per color).
    rand_color_list = [
        (randint(0, 255), randint(0, 255), randint(0, 255)) for _ in range(count)
    ]
    return rand_color_list
#......................................
class DetectionPredictor(BasePredictor):
    """YOLOv8 detection predictor that feeds each frame's detections to SORT and draws tracks.

    Reads module-level state: ``tracker`` (created by ``init_tracker()``) and
    ``rand_color_list`` (filled by ``random_color_list()``). Both must be
    initialised before prediction starts.
    """

    def get_annotator(self, img):
        """Return an ``Annotator`` for ``img`` using the configured line thickness."""
        return Annotator(img, line_width=self.args.line_thickness, example=str(self.model.names))

    def preprocess(self, img):
        """Move ``img`` to the model device, cast to fp16/fp32 and scale to [0, 1]."""
        img = torch.from_numpy(img).to(self.model.device)
        img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
        img /= 255  # 0 - 255 to 0.0 - 1.0
        return img

    def postprocess(self, preds, img, orig_img):
        """Run NMS, then rescale box columns 0-3 from network-input size to original-image size.

        Returns:
            The list produced by ``ops.non_max_suppression``, with boxes rescaled in place.
        """
        preds = ops.non_max_suppression(preds,
                                        self.args.conf,
                                        self.args.iou,
                                        agnostic=self.args.agnostic_nms,
                                        max_det=self.args.max_det)
        for i, pred in enumerate(preds):
            shape = orig_img[i].shape if self.webcam else orig_img.shape
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
        return preds

    def write_results(self, idx, preds, batch):
        """Log detections for image ``idx``, update the SORT tracker and draw tracks/boxes.

        Args:
            idx: Index of the image within the batch.
            preds: Per-image detection tensors from ``postprocess``. Each row's
                first six columns are ``(x1, y1, x2, y2, conf, cls)``.
            batch: ``(path, network_input, original_image)`` tuple.

        Returns:
            The log string for this image.
        """
        p, im, im0 = batch
        log_string = ""
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        self.seen += 1
        im0 = im0.copy()
        if self.webcam:  # batch_size >= 1
            log_string += f'{idx}: '
            frame = self.dataset.count
        else:
            frame = getattr(self.dataset, 'frame', 0)

        self.data_path = p
        self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
        log_string += '%gx%g ' % im.shape[2:]  # print string
        self.annotator = self.get_annotator(im0)

        det = preds[idx]
        self.all_outputs.append(det)

        # No-op for an empty (0, 6) detection tensor.
        for c in det[:, 5].unique():
            n = (det[:, 5] == c).sum()  # detections per class
            log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "

        # SORT must be updated on every frame -- including frames with no
        # detections (an empty (0, 6) array) -- so unmatched tracks age out.
        # Previously an early return skipped this on empty frames.
        dets_to_sort = det[:, :6].detach().cpu().numpy().astype(np.float64)
        tracked_dets = tracker.update(dets_to_sort)

        self._draw_track_trails(im0)

        if len(tracked_dets) > 0:
            bbox_xyxy = tracked_dets[:, :4]
            identities = tracked_dets[:, 8]
            categories = tracked_dets[:, 4]
            draw_boxes(im0, bbox_xyxy, identities, categories, self.model.names)
        return log_string

    @staticmethod
    def _draw_track_trails(img):
        """Draw each SORT track's centroid history onto ``img`` as connected line segments."""
        for track in tracker.getTrackers():
            points = list(track.centroidarr)
            if len(points) < 2:
                continue
            # Track IDs come from SORT and presumably keep growing over a long
            # video (verify in sort.py). Wrap them into the palette instead of
            # raising IndexError once they exceed len(rand_color_list).
            color = rand_color_list[int(track.id) % len(rand_color_list)]
            for start, end in zip(points, points[1:]):
                cv2.line(img,
                         (int(start[0]), int(start[1])),
                         (int(end[0]), int(end[1])),
                         color, thickness=3)
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def predict(cfg):
    """Hydra entry point: set up tracking state, fill config defaults and run the predictor.

    Args:
        cfg: Config loaded from ``DEFAULT_CONFIG``. Its ``model``, ``imgsz`` and
            ``source`` fields are normalised in place before the predictor is built.
    """
    # Global state later read by DetectionPredictor.write_results.
    init_tracker()
    random_color_list()

    requested_model = cfg.model
    cfg.model = requested_model if requested_model else "yolov8n.pt"

    cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2)  # check image size

    requested_source = cfg.source
    if requested_source is None:
        requested_source = ROOT / "assets"
    cfg.source = requested_source

    runner = DetectionPredictor(cfg)
    runner()


if __name__ == "__main__":
    predict()