yolov8.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. from tinygrad.nn import Conv2d, BatchNorm2d
  2. from tinygrad.tensor import Tensor
  3. import numpy as np
  4. from itertools import chain
  5. from pathlib import Path
  6. import cv2
  7. from collections import defaultdict
  8. import time, sys
  9. from tinygrad.helpers import fetch
  10. from tinygrad.nn.state import safe_load, load_state_dict
  11. #Model architecture from https://github.com/ultralytics/ultralytics/issues/189
# The upsampling class has been taken from this pull request https://github.com/tinygrad/tinygrad/pull/784 by dc-dc-dc. Now 2(?) models use upsampling. (RetinaNet and this)
  13. #Pre processing image functions.
def compute_transform(image, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32) -> Tensor:
  """Letterbox `image` (an HWC numpy array) toward `new_shape` and return it as a Tensor.

  Resizes while preserving aspect ratio, then pads borders with the gray value
  (114, 114, 114). With auto=True only pads up to the next multiple of `stride`;
  scaleFill=True stretches to `new_shape` exactly; scaleup=False only ever shrinks.
  """
  shape = image.shape[:2] # current shape [height, width]
  new_shape = (new_shape, new_shape) if isinstance(new_shape, int) else new_shape
  r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])  # resize ratio (target/current)
  r = min(r, 1.0) if not scaleup else r  # never enlarge when scaleup is False
  new_unpad = (int(round(shape[1] * r)), int(round(shape[0] * r)))  # (w, h) after resize
  dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # total padding needed
  # NOTE(review): when auto is False the padding is zeroed rather than padded up to
  # new_shape as in the ultralytics reference -- the output is then just the resized
  # image, not a full 640x640 canvas. Confirm this is intended for the non-auto path.
  dw, dh = (np.mod(dw, stride), np.mod(dh, stride)) if auto else (0.0, 0.0)
  new_unpad = (new_shape[1], new_shape[0]) if scaleFill else new_unpad
  dw /= 2  # split padding evenly between the two sides
  dh /= 2
  image = cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR) if shape[::-1] != new_unpad else image
  top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))  # +-0.1 resolves odd padding asymmetrically
  left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
  image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
  return Tensor(image)
def preprocess(im, imgsz=640, model_stride=32, model_pt=True):
  """Convert a list of BGR HWC uint8 images into a normalized (n, 3, h, w) RGB Tensor batch."""
  same_shapes = all(x.shape == im[0].shape for x in im)
  # minimal stride-aligned padding only when all images share a shape (and model is "pt")
  auto = same_shapes and model_pt
  im = [compute_transform(x, new_shape=imgsz, auto=auto, stride=model_stride) for x in im]
  im = Tensor.stack(*im) if len(im) > 1 else im[0].unsqueeze(0)
  im = im[..., ::-1].permute(0, 3, 1, 2) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
  im = im / 255.0 # 0 - 255 to 0.0 - 1.0
  return im
  38. # Post Processing functions
  39. def box_area(box):
  40. return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
  41. def box_iou(box1, box2):
  42. lt = np.maximum(box1[:, None, :2], box2[:, :2])
  43. rb = np.minimum(box1[:, None, 2:], box2[:, 2:])
  44. wh = np.clip(rb - lt, 0, None)
  45. inter = wh[:, :, 0] * wh[:, :, 1]
  46. area1 = box_area(box1)[:, None]
  47. area2 = box_area(box2)[None, :]
  48. iou = inter / (area1 + area2 - inter)
  49. return iou
  50. def compute_nms(boxes, scores, iou_threshold):
  51. order, keep = scores.argsort()[::-1], []
  52. while order.size > 0:
  53. i = order[0]
  54. keep.append(i)
  55. if order.size == 1:
  56. break
  57. iou = box_iou(boxes[i][None, :], boxes[order[1:]])
  58. inds = np.where(iou.squeeze() <= iou_threshold)[0]
  59. order = order[inds + 1]
  60. return np.array(keep)
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, agnostic=False, max_det=300, nc=0, max_wh=7680):
  """Confidence-filter raw YOLO output and run per-class NMS.

  prediction: (bs, 4 + nc + nm, anchors) array, or a list/tuple whose first element is it.
  Returns a list of bs arrays shaped (kept, 6 + nm): [x1, y1, x2, y2, conf, cls, mask...].
  """
  prediction = prediction[0] if isinstance(prediction, (list, tuple)) else prediction
  bs, nc = prediction.shape[0], nc or (prediction.shape[1] - 4)  # batch size, class count
  xc = np.amax(prediction[:, 4:4 + nc], axis=1) > conf_thres  # anchors whose best class beats the threshold
  nm = prediction.shape[1] - nc - 4  # extra mask channels (0 for plain detection)
  output = [np.zeros((0, 6 + nm))] * bs
  for xi, x in enumerate(prediction):
    x = x.swapaxes(0, -1)[xc[xi]]  # (channels, anchors) -> (candidates, channels)
    if not x.shape[0]: continue
    box, cls, mask = np.split(x, [4, 4 + nc], axis=1)
    # best class score and index per candidate
    conf, j = np.max(cls, axis=1, keepdims=True), np.argmax(cls, axis=1, keepdims=True)
    x = np.concatenate((xywh2xyxy(box), conf, j.astype(np.float32), mask), axis=1)
    x = x[conf.ravel() > conf_thres]
    if not x.shape[0]: continue
    x = x[np.argsort(-x[:, 4])]  # sort by confidence, descending
    # offset boxes by class id so different classes never suppress each other
    c = x[:, 5:6] * (0 if agnostic else max_wh)
    boxes, scores = x[:, :4] + c, x[:, 4]
    i = compute_nms(boxes, scores, iou_thres)[:max_det]
    output[xi] = x[i]
  return output
def postprocess(preds, img, orig_imgs):
  """Apply NMS to raw predictions and rescale boxes back to the original image sizes."""
  print('copying to CPU now for post processing')
  #if you are on CPU, this causes an overflow runtime error. doesn't "seem" to make any difference in the predictions though.
  # TODO: make non_max_suppression in tinygrad - to make this faster
  preds = preds.numpy() if isinstance(preds, Tensor) else preds
  preds = non_max_suppression(prediction=preds, conf_thres=0.25, iou_thres=0.7, agnostic=False, max_det=300)
  all_preds = []
  for i, pred in enumerate(preds):
    orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
    if not isinstance(orig_imgs, Tensor):
      # map boxes from the letterboxed network input back to original image coordinates
      pred[:, :4] = scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
    all_preds.append(pred)
  return all_preds
def draw_bounding_boxes_and_save(orig_img_paths, output_img_paths, all_predictions, class_labels, iou_threshold=0.5):
  """Draw labeled detection boxes onto each image and write the result to disk.

  orig_img_paths may hold file paths or raw encoded-image numpy buffers; predictions are
  rows of [x1, y1, x2, y2, conf, class_id]. Also prints a per-image object count summary.
  """
  # deterministic pseudo-color per class label
  color_dict = {label: tuple((((i+1) * 50) % 256, ((i+1) * 100) % 256, ((i+1) * 150) % 256)) for i, label in enumerate(class_labels)}
  font = cv2.FONT_HERSHEY_SIMPLEX
  def is_bright_color(color):
    # perceived luminance (ITU-R BT.601 weights); used to pick readable label text color
    r, g, b = color
    brightness = (r * 299 + g * 587 + b * 114) / 1000
    return brightness > 127
  for img_idx, (orig_img_path, output_img_path, predictions) in enumerate(zip(orig_img_paths, output_img_paths, all_predictions)):
    predictions = np.array(predictions)
    orig_img = cv2.imread(orig_img_path) if not isinstance(orig_img_path, np.ndarray) else cv2.imdecode(orig_img_path, 1)
    height, width, _ = orig_img.shape
    # scale line thickness and font with image size
    box_thickness = int((height + width) / 400)
    font_scale = (height + width) / 2500
    grouped_preds = defaultdict(list)
    object_count = defaultdict(int)
    for pred_np in predictions:
      grouped_preds[int(pred_np[-1])].append(pred_np)  # group detections by class id
    def draw_box_and_label(pred, color):
      # draws one box plus a filled label background; reads `class_id` from the enclosing loop
      x1, y1, x2, y2, conf, _ = pred
      x1, y1, x2, y2 = map(int, (x1, y1, x2, y2))
      cv2.rectangle(orig_img, (x1, y1), (x2, y2), color, box_thickness)
      label = f"{class_labels[class_id]} {conf:.2f}"
      text_size, _ = cv2.getTextSize(label, font, font_scale, 1)
      # place the label above the box unless it would leave the image
      label_y, bg_y = (y1 - 4, y1 - text_size[1] - 4) if y1 - text_size[1] - 4 > 0 else (y1 + text_size[1], y1)
      cv2.rectangle(orig_img, (x1, bg_y), (x1 + text_size[0], bg_y + text_size[1]), color, -1)
      font_color = (0, 0, 0) if is_bright_color(color) else (255, 255, 255)
      cv2.putText(orig_img, label, (x1, label_y), font, font_scale, font_color, 1, cv2.LINE_AA)
    for class_id, pred_list in grouped_preds.items():
      pred_list = np.array(pred_list)
      # greedy per-class pass: draw the most confident box, drop high-IoU duplicates
      while len(pred_list) > 0:
        max_conf_idx = np.argmax(pred_list[:, 4])
        max_conf_pred = pred_list[max_conf_idx]
        pred_list = np.delete(pred_list, max_conf_idx, axis=0)
        color = color_dict[class_labels[class_id]]
        draw_box_and_label(max_conf_pred, color)
        object_count[class_labels[class_id]] += 1
        iou_scores = box_iou(np.array([max_conf_pred[:4]]), pred_list[:, :4])
        low_iou_indices = np.where(iou_scores[0] < iou_threshold)[0]
        pred_list = pred_list[low_iou_indices]
        # NOTE(review): nesting reconstructed from a whitespace-mangled source -- this
        # re-draws the surviving low-IoU boxes each round (they are drawn again later
        # when they become the max-confidence box); visually idempotent. Confirm intent.
        for low_conf_pred in pred_list:
          draw_box_and_label(low_conf_pred, color)
    print(f"Image {img_idx + 1}:")
    print("Objects detected:")
    for obj, count in object_count.items():
      print(f"- {obj}: {count}")
    cv2.imwrite(output_img_path, orig_img)
    print(f'saved detections at {output_img_path}')
  141. # utility functions for forward pass.
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
  """Decode (left, top, right, bottom) distances from anchor points into boxes.

  `dim` selects the axis that `distance` is split on; the decoded box is always
  concatenated on dim=1 (matches the caller, which passes dim=1).
  Returns (cx, cy, w, h) when xywh is True, else (x1, y1, x2, y2).
  """
  lt, rb = distance.chunk(2, dim)
  x1y1 = anchor_points - lt
  x2y2 = anchor_points + rb
  if xywh:
    c_xy = (x1y1 + x2y2) / 2
    wh = x2y2 - x1y1
    return c_xy.cat(wh, dim=1)
  return x1y1.cat(x2y2, dim=1)
def make_anchors(feats, strides, grid_cell_offset=0.5):
  """Build anchor center points and per-anchor stride values for a list of feature maps.

  Returns (anchor_points (total_anchors, 2), stride_tensor (total_anchors, 1)).
  """
  anchor_points, stride_tensor = [], []
  assert feats is not None
  for i, stride in enumerate(strides):
    _, _, h, w = feats[i].shape
    sx = Tensor.arange(w) + grid_cell_offset  # shift x to cell centers
    sy = Tensor.arange(h) + grid_cell_offset  # shift y to cell centers
    # this is np.meshgrid but in tinygrad
    sx = sx.reshape(1, -1).repeat([h, 1]).reshape(-1)
    sy = sy.reshape(-1, 1).repeat([1, w]).reshape(-1)
    anchor_points.append(Tensor.stack(sx, sy, dim=-1).reshape(-1, 2))
    stride_tensor.append(Tensor.full((h * w), stride))
  # NOTE: hard-codes exactly three feature levels (matches the three detection-head inputs)
  anchor_points = anchor_points[0].cat(anchor_points[1], anchor_points[2])
  stride_tensor = stride_tensor[0].cat(stride_tensor[1], stride_tensor[2]).unsqueeze(1)
  return anchor_points, stride_tensor
  166. # this function is from the original implementation
  167. def autopad(k, p=None, d=1): # kernel, padding, dilation
  168. if d > 1:
  169. k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
  170. if p is None:
  171. p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
  172. return p
  173. def clip_boxes(boxes, shape):
  174. boxes[..., [0, 2]] = np.clip(boxes[..., [0, 2]], 0, shape[1]) # x1, x2
  175. boxes[..., [1, 3]] = np.clip(boxes[..., [1, 3]], 0, shape[0]) # y1, y2
  176. return boxes
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
  """Rescale xyxy boxes from the letterboxed shape img1_shape (h, w) back to img0_shape (h, w, ...).

  If ratio_pad is given it is used directly as the scalar resize gain.
  Mutates (and returns) the underlying numpy array.
  """
  gain = ratio_pad if ratio_pad else min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
  # (x, y) padding that the letterbox added on each side
  pad = ((img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2)
  boxes_np = boxes.numpy() if isinstance(boxes, Tensor) else boxes
  boxes_np[..., [0, 2]] -= pad[0]
  boxes_np[..., [1, 3]] -= pad[1]
  boxes_np[..., :4] /= gain
  boxes_np = clip_boxes(boxes_np, img0_shape)
  return boxes_np
  186. def xywh2xyxy(x):
  187. xy = x[..., :2] # center x, y
  188. wh = x[..., 2:4] # width, height
  189. xy1 = xy - wh / 2 # top left x, y
  190. xy2 = xy + wh / 2 # bottom right x, y
  191. result = np.concatenate((xy1, xy2), axis=-1)
  192. return Tensor(result) if isinstance(x, Tensor) else result
  193. def get_variant_multiples(variant):
  194. return {'n':(0.33, 0.25, 2.0), 's':(0.33, 0.50, 2.0), 'm':(0.67, 0.75, 1.5), 'l':(1.0, 1.0, 1.0), 'x':(1, 1.25, 1.0) }.get(variant, None)
  195. def label_predictions(all_predictions):
  196. class_index_count = defaultdict(int)
  197. for predictions in all_predictions:
  198. predictions = np.array(predictions)
  199. for pred_np in predictions:
  200. class_id = int(pred_np[-1])
  201. class_index_count[class_id] += 1
  202. return dict(class_index_count)
  203. #this is taken from https://github.com/tinygrad/tinygrad/pull/784/files by dc-dc-dc (Now 2 models use upsampling)
class Upsample:
  """Nearest-neighbor upsampling by an integer scale factor.

  Taken from tinygrad PR #784 by dc-dc-dc; works for 3-5 dimensional tensors.
  """
  def __init__(self, scale_factor:int, mode: str = "nearest") -> None:
    assert mode == "nearest" # only mode supported for now
    self.mode = mode
    self.scale_factor = scale_factor
  def __call__(self, x: Tensor) -> Tensor:
    assert len(x.shape) > 2 and len(x.shape) <= 5
    (b, c), _lens = x.shape[:2], len(x.shape[2:])
    # broadcast against ones to replicate each spatial element scale_factor times per axis
    tmp = x.reshape([b, c, -1] + [1] * _lens) * Tensor.ones(*[1, 1, 1] + [self.scale_factor] * _lens)
    # interleave each replicated axis next to its source axis, then collapse the pairs
    return tmp.reshape(list(x.shape) + [self.scale_factor] * _lens).permute([0, 1] + list(chain.from_iterable([[y+2, y+2+_lens] for y in range(_lens)]))).reshape([b, c] + [x * self.scale_factor for x in x.shape[2:]])
  214. class Conv_Block:
  215. def __init__(self, c1, c2, kernel_size=1, stride=1, groups=1, dilation=1, padding=None):
  216. self.conv = Conv2d(c1,c2, kernel_size, stride, padding=autopad(kernel_size, padding, dilation), bias=False, groups=groups, dilation=dilation)
  217. self.bn = BatchNorm2d(c2, eps=0.001)
  218. def __call__(self, x):
  219. return self.bn(self.conv(x)).silu()
  220. class Bottleneck:
  221. def __init__(self, c1, c2 , shortcut: bool, g=1, kernels: list = (3,3), channel_factor=0.5):
  222. c_ = int(c2 * channel_factor)
  223. self.cv1 = Conv_Block(c1, c_, kernel_size=kernels[0], stride=1, padding=None)
  224. self.cv2 = Conv_Block(c_, c2, kernel_size=kernels[1], stride=1, padding=None, groups=g)
  225. self.residual = c1 == c2 and shortcut
  226. def __call__(self, x):
  227. return x + self.cv2(self.cv1(x)) if self.residual else self.cv2(self.cv1(x))
class C2f:
  """CSP-style block: split features in two, chain n Bottlenecks on one half, concat all, fuse with 1x1 conv."""
  def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
    self.c = int(c2 * e)  # hidden channels per split
    self.cv1 = Conv_Block(c1, 2 * self.c, 1,)
    self.cv2 = Conv_Block((2 + n) * self.c, c2, 1)
    self.bottleneck = [Bottleneck(self.c, self.c, shortcut, g, kernels=[(3, 3), (3, 3)], channel_factor=1.0) for _ in range(n)]
  def __call__(self, x):
    y= list(self.cv1(x).chunk(2, 1))
    y.extend(m(y[-1]) for m in self.bottleneck)  # each bottleneck consumes the previous output
    z = y[0]
    for i in y[1:]: z = z.cat(i, dim=1)  # concat split halves plus every bottleneck output
    return self.cv2(z)
  240. class SPPF:
  241. def __init__(self, c1, c2, k=5):
  242. c_ = c1 // 2 # hidden channels
  243. self.cv1 = Conv_Block(c1, c_, 1, 1, padding=None)
  244. self.cv2 = Conv_Block(c_ * 4, c2, 1, 1, padding=None)
  245. # TODO: this pads with 0s, whereas torch function pads with -infinity. This results in a < 2% difference in prediction which does not make a difference visually.
  246. self.maxpool = lambda x : x.pad2d((k // 2, k // 2, k // 2, k // 2)).max_pool2d(kernel_size=k, stride=1)
  247. def __call__(self, x):
  248. x = self.cv1(x)
  249. x2 = self.maxpool(x)
  250. x3 = self.maxpool(x2)
  251. x4 = self.maxpool(x3)
  252. return self.cv2(x.cat(x2, x3, x4, dim=1))
class DFL:
  """Distribution Focal Loss decoder: turns per-side discrete distributions into expected distances."""
  def __init__(self, c1=16):
    self.conv = Conv2d(c1, 1, 1, bias=False)
    x = Tensor.arange(c1)
    # fixed weights 0..c1-1 so the 1x1 conv computes the expectation over the softmaxed bins
    self.conv.weight.replace(x.reshape(1, c1, 1, 1))
    self.c1 = c1
  def __call__(self, x):
    b, c, a = x.shape # batch, channels, anchors
    # (b, 4*c1, a) -> softmax over the c1 bins of each of the 4 sides -> (b, 4, a) distances
    return self.conv(x.reshape(b, 4, self.c1, a).transpose(2, 1).softmax(1)).reshape(b, 4, a)
  262. #backbone
  263. class Darknet:
  264. def __init__(self, w, r, d):
  265. self.b1 = [Conv_Block(c1=3, c2= int(64*w), kernel_size=3, stride=2, padding=1), Conv_Block(int(64*w), int(128*w), kernel_size=3, stride=2, padding=1)]
  266. self.b2 = [C2f(c1=int(128*w), c2=int(128*w), n=round(3*d), shortcut=True), Conv_Block(int(128*w), int(256*w), 3, 2, 1), C2f(int(256*w), int(256*w), round(6*d), True)]
  267. self.b3 = [Conv_Block(int(256*w), int(512*w), kernel_size=3, stride=2, padding=1), C2f(int(512*w), int(512*w), round(6*d), True)]
  268. self.b4 = [Conv_Block(int(512*w), int(512*w*r), kernel_size=3, stride=2, padding=1), C2f(int(512*w*r), int(512*w*r), round(3*d), True)]
  269. self.b5 = [SPPF(int(512*w*r), int(512*w*r), 5)]
  270. def return_modules(self):
  271. return [*self.b1, *self.b2, *self.b3, *self.b4, *self.b5]
  272. def __call__(self, x):
  273. x1 = x.sequential(self.b1)
  274. x2 = x1.sequential(self.b2)
  275. x3 = x2.sequential(self.b3)
  276. x4 = x3.sequential(self.b4)
  277. x5 = x4.sequential(self.b5)
  278. return (x2, x3, x5)
  279. #yolo fpn (neck)
  280. class Yolov8NECK:
  281. def __init__(self, w, r, d): #width_multiple, ratio_multiple, depth_multiple
  282. self.up = Upsample(2, mode='nearest')
  283. self.n1 = C2f(c1=int(512*w*(1+r)), c2=int(512*w), n=round(3*d), shortcut=False)
  284. self.n2 = C2f(c1=int(768*w), c2=int(256*w), n=round(3*d), shortcut=False)
  285. self.n3 = Conv_Block(c1=int(256*w), c2=int(256*w), kernel_size=3, stride=2, padding=1)
  286. self.n4 = C2f(c1=int(768*w), c2=int(512*w), n=round(3*d), shortcut=False)
  287. self.n5 = Conv_Block(c1=int(512* w), c2=int(512 * w), kernel_size=3, stride=2, padding=1)
  288. self.n6 = C2f(c1=int(512*w*(1+r)), c2=int(512*w*r), n=round(3*d), shortcut=False)
  289. def return_modules(self):
  290. return [self.n1, self.n2, self.n3, self.n4, self.n5, self.n6]
  291. def __call__(self, p3, p4, p5):
  292. x = self.n1(self.up(p5).cat(p4, dim=1))
  293. head_1 = self.n2(self.up(x).cat(p3, dim=1))
  294. head_2 = self.n4(self.n3(head_1).cat(x, dim=1))
  295. head_3 = self.n6(self.n5(head_2).cat(p5, dim=1))
  296. return [head_1, head_2, head_3]
  297. #task specific head.
class DetectionHead:
  """Decoupled YOLOv8 detection head: DFL box-regression and classification branches per level."""
  def __init__(self, nc=80, filters=()):
    self.ch = 16  # DFL bins per box side
    self.nc = nc # number of classes
    self.nl = len(filters)  # number of feature levels
    self.no = nc + self.ch * 4 # outputs per anchor: class scores + 4*ch box-distribution logits
    self.stride = [8, 16, 32]
    c1 = max(filters[0], self.nc)
    c2 = max((filters[0] // 4, self.ch * 4))
    self.dfl = DFL(self.ch)
    # classification branch per level
    self.cv3 = [[Conv_Block(x, c1, 3), Conv_Block(c1, c1, 3), Conv2d(c1, self.nc, 1)] for x in filters]
    # box-regression branch per level
    self.cv2 = [[Conv_Block(x, c2, 3), Conv_Block(c2, c2, 3), Conv2d(c2, 4 * self.ch, 1)] for x in filters]
  def __call__(self, x):
    for i in range(self.nl):
      x[i] = (x[i].sequential(self.cv2[i]).cat(x[i].sequential(self.cv3[i]), dim=1))
    self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
    y = [(i.reshape(x[0].shape[0], self.no, -1)) for i in x]  # flatten each level's spatial dims
    x_cat = y[0].cat(y[1], y[2], dim=2)  # concat anchors across all three levels
    box, cls = x_cat[:, :self.ch * 4], x_cat[:, self.ch * 4:]
    # decode DFL distributions to xywh boxes in input-pixel coordinates
    dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
    z = dbox.cat(cls.sigmoid(), dim=1)
    return z
  320. class YOLOv8:
  321. def __init__(self, w, r, d, num_classes): #width_multiple, ratio_multiple, depth_multiple
  322. self.net = Darknet(w, r, d)
  323. self.fpn = Yolov8NECK(w, r, d)
  324. self.head = DetectionHead(num_classes, filters=(int(256*w), int(512*w), int(512*w*r)))
  325. def __call__(self, x):
  326. x = self.net(x)
  327. x = self.fpn(*x)
  328. return self.head(x)
  329. def return_all_trainable_modules(self):
  330. backbone_modules = [*range(10)]
  331. yolov8neck_modules = [12, 15, 16, 18, 19, 21]
  332. yolov8_head_weights = [(22, self.head)]
  333. return [*zip(backbone_modules, self.net.return_modules()), *zip(yolov8neck_modules, self.fpn.return_modules()), *yolov8_head_weights]
if __name__ == '__main__':
  # usage : python3 yolov8.py "image_URL OR image_path" "v8 variant" (optional, n is default)
  if len(sys.argv) < 2:
    print("Error: Image URL or path not provided.")
    sys.exit(1)
  img_path = sys.argv[1]
  # fall back to the nano variant when none is given (print returns None, so `or 'n'` applies)
  yolo_variant = sys.argv[2] if len(sys.argv) >= 3 else (print("No variant given, so choosing 'n' as the default. Yolov8 has different variants, you can choose from ['n', 's', 'm', 'l', 'x']") or 'n')
  print(f'running inference for YOLO version {yolo_variant}')
  output_folder_path = Path('./outputs_yolov8')
  output_folder_path.mkdir(parents=True, exist_ok=True)
  #absolute image path or URL
  image_location = [np.frombuffer(fetch(img_path).read_bytes(), np.uint8)]
  image = [cv2.imdecode(image_location[0], 1)]
  out_paths = [(output_folder_path / f"{Path(img_path).stem}_output{Path(img_path).suffix or '.png'}").as_posix()]
  if not isinstance(image[0], np.ndarray):
    print('Error in image loading. Check your image file.')
    sys.exit(1)
  pre_processed_image = preprocess(image)
  # Different YOLOv8 variants use different w , r, and d multiples. For a list , refer to this yaml file (the scales section) https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/models/v8/yolov8.yaml
  depth, width, ratio = get_variant_multiples(yolo_variant)
  yolo_infer = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
  # load pretrained safetensors weights for the chosen variant
  state_dict = safe_load(fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors'))
  load_state_dict(yolo_infer, state_dict)
  st = time.time()
  predictions = yolo_infer(pre_processed_image)
  print(f'did inference in {int(round(((time.time() - st) * 1000)))}ms')
  post_predictions = postprocess(preds=predictions, img=pre_processed_image, orig_imgs=image)
  #v8 and v3 have same 80 class names for Object Detection
  class_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names').read_text().split("\n")
  draw_bounding_boxes_and_save(orig_img_paths=image_location, output_img_paths=out_paths, all_predictions=post_predictions, class_labels=class_labels)
  # TODO for later:
  # 1. Fix SPPF minor difference due to maxpool
  # 2. AST exp overflow warning while on cpu
  # 3. Make NMS faster
  # 4. Add video inference and webcam support