# retinanet.py
  1. import math
  2. from tinygrad.helpers import flatten, get_child
  3. import tinygrad.nn as nn
  4. from extra.models.resnet import ResNet
  5. import numpy as np
  6. def nms(boxes, scores, thresh=0.5):
  7. x1, y1, x2, y2 = np.rollaxis(boxes, 1)
  8. areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  9. to_process, keep = scores.argsort()[::-1], []
  10. while to_process.size > 0:
  11. cur, to_process = to_process[0], to_process[1:]
  12. keep.append(cur)
  13. inter_x1 = np.maximum(x1[cur], x1[to_process])
  14. inter_y1 = np.maximum(y1[cur], y1[to_process])
  15. inter_x2 = np.minimum(x2[cur], x2[to_process])
  16. inter_y2 = np.minimum(y2[cur], y2[to_process])
  17. inter_area = np.maximum(0, inter_x2 - inter_x1 + 1) * np.maximum(0, inter_y2 - inter_y1 + 1)
  18. iou = inter_area / (areas[cur] + areas[to_process] - inter_area)
  19. to_process = to_process[np.where(iou <= thresh)[0]]
  20. return keep
  21. def decode_bbox(offsets, anchors):
  22. dx, dy, dw, dh = np.rollaxis(offsets, 1)
  23. widths, heights = anchors[:, 2] - anchors[:, 0], anchors[:, 3] - anchors[:, 1]
  24. cx, cy = anchors[:, 0] + 0.5 * widths, anchors[:, 1] + 0.5 * heights
  25. pred_cx, pred_cy = dx * widths + cx, dy * heights + cy
  26. pred_w, pred_h = np.exp(dw) * widths, np.exp(dh) * heights
  27. pred_x1, pred_y1 = pred_cx - 0.5 * pred_w, pred_cy - 0.5 * pred_h
  28. pred_x2, pred_y2 = pred_cx + 0.5 * pred_w, pred_cy + 0.5 * pred_h
  29. return np.stack([pred_x1, pred_y1, pred_x2, pred_y2], axis=1, dtype=np.float32)
  30. def generate_anchors(input_size, grid_sizes, scales, aspect_ratios):
  31. assert len(scales) == len(aspect_ratios) == len(grid_sizes)
  32. anchors = []
  33. for s, ar, gs in zip(scales, aspect_ratios, grid_sizes):
  34. s, ar = np.array(s), np.array(ar)
  35. h_ratios = np.sqrt(ar)
  36. w_ratios = 1 / h_ratios
  37. ws = (w_ratios[:, None] * s[None, :]).reshape(-1)
  38. hs = (h_ratios[:, None] * s[None, :]).reshape(-1)
  39. base_anchors = (np.stack([-ws, -hs, ws, hs], axis=1) / 2).round()
  40. stride_h, stride_w = input_size[0] // gs[0], input_size[1] // gs[1]
  41. shifts_x, shifts_y = np.meshgrid(np.arange(gs[1]) * stride_w, np.arange(gs[0]) * stride_h)
  42. shifts_x = shifts_x.reshape(-1)
  43. shifts_y = shifts_y.reshape(-1)
  44. shifts = np.stack([shifts_x, shifts_y, shifts_x, shifts_y], axis=1, dtype=np.float32)
  45. anchors.append((shifts[:, None] + base_anchors[None, :]).reshape(-1, 4))
  46. return anchors
class RetinaNet:
  """RetinaNet object detector: ResNet-FPN backbone plus shared conv heads.

  forward() returns predictions of shape (BS, (H1W1+...+HmWm)A, 4 + K):
  4 box-regression offsets followed by K per-class scores for every anchor.
  """
  def __init__(self, backbone: ResNet, num_classes=264, num_anchors=9, scales=None, aspect_ratios=None):
    assert isinstance(backbone, ResNet)
    # default anchors: three octave-spaced sizes per level, base sizes 32..512
    scales = tuple((i, int(i*2**(1/3)), int(i*2**(2/3))) for i in 2**np.arange(5, 10)) if scales is None else scales
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(scales) if aspect_ratios is None else aspect_ratios
    self.num_anchors, self.num_classes = num_anchors, num_classes
    # each level must yield exactly num_anchors anchors per grid cell
    assert len(scales) == len(aspect_ratios) and all(self.num_anchors == len(s) * len(ar) for s, ar in zip(scales, aspect_ratios))
    self.backbone = ResNetFPN(backbone)
    self.head = RetinaHead(self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes)
    # deferred so anchors can be generated for any input size at postprocess time
    self.anchor_gen = lambda input_size: generate_anchors(input_size, self.backbone.compute_grid_sizes(input_size), scales, aspect_ratios)

  def __call__(self, x):
    return self.forward(x)

  def forward(self, x):
    return self.head(self.backbone(x))

  def load_from_pretrained(self):
    """Download and load pretrained weights matching the backbone variant."""
    model_urls = {
      (50, 1, 64): "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
      (50, 32, 4): "https://zenodo.org/record/6605272/files/retinanet_model_10.zip",
    }
    self.url = model_urls[(self.backbone.body.num, self.backbone.body.groups, self.backbone.body.base_width)]
    from torch.hub import load_state_dict_from_url
    state_dict = load_state_dict_from_url(self.url, progress=True, map_location='cpu')
    # some checkpoints nest the weights under a 'model' key
    state_dict = state_dict['model'] if 'model' in state_dict.keys() else state_dict
    for k, v in state_dict.items():
      obj = get_child(self, k)
      dat = v.detach().numpy()
      assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
      obj.assign(dat)

  # predictions: (BS, (H1W1+...+HmWm)A, 4 + K)
  def postprocess_detections(self, predictions, input_size=(800, 800), image_sizes=None, orig_image_sizes=None, score_thresh=0.05, topk_candidates=1000, nms_thresh=0.5):
    """Turn raw network output into per-image detections (xywh boxes, scores, labels)."""
    anchors = self.anchor_gen(input_size)
    grid_sizes = self.backbone.compute_grid_sizes(input_size)
    # indices that split the flat anchor axis back into per-FPN-level chunks
    split_idx = np.cumsum([int(self.num_anchors * sz[0] * sz[1]) for sz in grid_sizes[:-1]])
    detections = []
    for i, predictions_per_image in enumerate(predictions):
      h, w = input_size if image_sizes is None else image_sizes[i]
      predictions_per_image = np.split(predictions_per_image, split_idx)
      offsets_per_image = [br[:, :4] for br in predictions_per_image]
      scores_per_image = [cl[:, 4:] for cl in predictions_per_image]
      image_boxes, image_scores, image_labels = [], [], []
      for offsets_per_level, scores_per_level, anchors_per_level in zip(offsets_per_image, scores_per_image, anchors):
        # remove low scoring boxes
        scores_per_level = scores_per_level.flatten()
        keep_idxs = scores_per_level > score_thresh
        scores_per_level = scores_per_level[keep_idxs]
        # keep topk
        topk_idxs = np.where(keep_idxs)[0]
        num_topk = min(len(topk_idxs), topk_candidates)
        sort_idxs = scores_per_level.argsort()[-num_topk:][::-1]
        topk_idxs, scores_per_level = topk_idxs[sort_idxs], scores_per_level[sort_idxs]
        # bbox coords from offsets
        # flat index encodes (anchor, class) as anchor_idx * num_classes + class_idx
        anchor_idxs = topk_idxs // self.num_classes
        labels_per_level = topk_idxs % self.num_classes
        boxes_per_level = decode_bbox(offsets_per_level[anchor_idxs], anchors_per_level[anchor_idxs])
        # clip to image size
        clipped_x = boxes_per_level[:, 0::2].clip(0, w)
        clipped_y = boxes_per_level[:, 1::2].clip(0, h)
        boxes_per_level = np.stack([clipped_x, clipped_y], axis=2).reshape(-1, 4)
        image_boxes.append(boxes_per_level)
        image_scores.append(scores_per_level)
        image_labels.append(labels_per_level)
      image_boxes = np.concatenate(image_boxes)
      image_scores = np.concatenate(image_scores)
      image_labels = np.concatenate(image_labels)
      # nms for each class
      keep_mask = np.zeros_like(image_scores, dtype=bool)
      for class_id in np.unique(image_labels):
        curr_indices = np.where(image_labels == class_id)[0]
        curr_keep_indices = nms(image_boxes[curr_indices], image_scores[curr_indices], nms_thresh)
        keep_mask[curr_indices[curr_keep_indices]] = True
      keep = np.where(keep_mask)[0]
      # final detections sorted by score, highest first
      keep = keep[image_scores[keep].argsort()[::-1]]
      # resize bboxes back to original size
      image_boxes = image_boxes[keep]
      if orig_image_sizes is not None:
        resized_x = image_boxes[:, 0::2] * orig_image_sizes[i][1] / w
        resized_y = image_boxes[:, 1::2] * orig_image_sizes[i][0] / h
        image_boxes = np.stack([resized_x, resized_y], axis=2).reshape(-1, 4)
      # xywh format
      image_boxes = np.concatenate([image_boxes[:, :2], image_boxes[:, 2:] - image_boxes[:, :2]], axis=1)
      detections.append({"boxes":image_boxes, "scores":image_scores[keep], "labels":image_labels[keep]})
    return detections
  129. class ClassificationHead:
  130. def __init__(self, in_channels, num_anchors, num_classes):
  131. self.num_classes = num_classes
  132. self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
  133. self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, padding=1)
  134. def __call__(self, x):
  135. out = [self.cls_logits(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, self.num_classes) for feat in x]
  136. return out[0].cat(*out[1:], dim=1).sigmoid()
  137. class RegressionHead:
  138. def __init__(self, in_channels, num_anchors):
  139. self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
  140. self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, padding=1)
  141. def __call__(self, x):
  142. out = [self.bbox_reg(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, 4) for feat in x]
  143. return out[0].cat(*out[1:], dim=1)
  144. class RetinaHead:
  145. def __init__(self, in_channels, num_anchors, num_classes):
  146. self.classification_head = ClassificationHead(in_channels, num_anchors, num_classes)
  147. self.regression_head = RegressionHead(in_channels, num_anchors)
  148. def __call__(self, x):
  149. pred_bbox, pred_class = self.regression_head(x), self.classification_head(x)
  150. out = pred_bbox.cat(pred_class, dim=-1)
  151. return out
  152. class ResNetFPN:
  153. def __init__(self, resnet, out_channels=256, returned_layers=[2, 3, 4]):
  154. self.out_channels = out_channels
  155. self.body = resnet
  156. in_channels_list = [(self.body.in_planes // 8) * 2 ** (i - 1) for i in returned_layers]
  157. self.fpn = FPN(in_channels_list, out_channels)
  158. # this is needed to decouple inference from postprocessing (anchors generation)
  159. def compute_grid_sizes(self, input_size):
  160. return np.ceil(np.array(input_size)[None, :] / 2 ** np.arange(3, 8)[:, None])
  161. def __call__(self, x):
  162. out = self.body.bn1(self.body.conv1(x)).relu()
  163. out = out.pad2d([1,1,1,1]).max_pool2d((3,3), 2)
  164. out = out.sequential(self.body.layer1)
  165. p3 = out.sequential(self.body.layer2)
  166. p4 = p3.sequential(self.body.layer3)
  167. p5 = p4.sequential(self.body.layer4)
  168. return self.fpn([p3, p4, p5])
  169. class ExtraFPNBlock:
  170. def __init__(self, in_channels, out_channels):
  171. self.p6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
  172. self.p7 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
  173. self.use_P5 = in_channels == out_channels
  174. def __call__(self, p, c):
  175. p5, c5 = p[-1], c[-1]
  176. x = p5 if self.use_P5 else c5
  177. p6 = self.p6(x)
  178. p7 = self.p7(p6.relu())
  179. p.extend([p6, p7])
  180. return p
  181. class FPN:
  182. def __init__(self, in_channels_list, out_channels, extra_blocks=None):
  183. self.inner_blocks, self.layer_blocks = [], []
  184. for in_channels in in_channels_list:
  185. self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
  186. self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
  187. self.extra_blocks = ExtraFPNBlock(256, 256) if extra_blocks is None else extra_blocks
  188. def __call__(self, x):
  189. last_inner = self.inner_blocks[-1](x[-1])
  190. results = [self.layer_blocks[-1](last_inner)]
  191. for idx in range(len(x) - 2, -1, -1):
  192. inner_lateral = self.inner_blocks[idx](x[idx])
  193. # upsample to inner_lateral's shape
  194. (ih, iw), (oh, ow), prefix = last_inner.shape[-2:], inner_lateral.shape[-2:], last_inner.shape[:-2]
  195. eh, ew = math.ceil(oh / ih), math.ceil(ow / iw)
  196. inner_top_down = last_inner.reshape(*prefix, ih, 1, iw, 1).expand(*prefix, ih, eh, iw, ew).reshape(*prefix, ih*eh, iw*ew)[:, :, :oh, :ow]
  197. last_inner = inner_lateral + inner_top_down
  198. results.insert(0, self.layer_blocks[idx](last_inner))
  199. if self.extra_blocks is not None:
  200. results = self.extra_blocks(results, x)
  201. return results
if __name__ == "__main__":
  # smoke test: build the ResNeXt50_32X4D backbone and download pretrained weights
  from extra.models.resnet import ResNeXt50_32X4D
  backbone = ResNeXt50_32X4D()
  retina = RetinaNet(backbone)
  retina.load_from_pretrained()