# mask_rcnn.py — Mask R-CNN model implementation (tinygrad)
  1. import re
  2. import math
  3. import os
  4. import numpy as np
  5. from pathlib import Path
  6. from tinygrad import nn, Tensor, dtypes
  7. from tinygrad.tensor import _to_np_dtype
  8. from tinygrad.helpers import get_child, fetch
  9. from tinygrad.nn.state import torch_load
  10. from extra.models.resnet import ResNet
  11. from extra.models.retinanet import nms as _box_nms
# When FULL_TINYGRAD is unset or '0', gather operations route through numpy
# (npgather) instead of the pure-tinygrad _gather, which is much slower.
USE_NP_GATHER = os.getenv('FULL_TINYGRAD', '0') == '0'
def rint(tensor):
  # Round to the nearest integer, halves rounded away from zero.
  # Doubling first makes every half-integer representable after the int cast,
  # then floor (negatives) / ceil (positives) snaps back to an integer.
  # NOTE(review): assumes cast(dtypes.int32) truncates toward zero — confirm.
  x = (tensor*2).cast(dtypes.int32).contiguous().cast(dtypes.float32)/2
  return (x<0).where(x.floor(), x.ceil())
  16. def nearest_interpolate(tensor, scale_factor):
  17. bs, c, py, px = tensor.shape
  18. return tensor.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, scale_factor, px, scale_factor).reshape(bs, c, py * scale_factor, px * scale_factor)
  19. def meshgrid(x, y):
  20. grid_x = Tensor.cat(*[x[idx:idx+1].expand(y.shape).unsqueeze(0) for idx in range(x.shape[0])])
  21. grid_y = Tensor.cat(*[y.unsqueeze(0)]*x.shape[0])
  22. return grid_x.reshape(-1, 1), grid_y.reshape(-1, 1)
  23. def topk(input_, k, dim=-1, largest=True, sorted=False):
  24. k = min(k, input_.shape[dim]-1)
  25. input_ = input_.numpy()
  26. if largest: input_ *= -1
  27. ind = np.argpartition(input_, k, axis=dim)
  28. if largest: input_ *= -1
  29. ind = np.take(ind, np.arange(k), axis=dim) # k non-sorted indices
  30. input_ = np.take_along_axis(input_, ind, axis=dim) # k non-sorted values
  31. if not sorted: return Tensor(input_), ind
  32. if largest: input_ *= -1
  33. ind_part = np.argsort(input_, axis=dim)
  34. ind = np.take_along_axis(ind, ind_part, axis=dim)
  35. if largest: input_ *= -1
  36. val = np.take_along_axis(input_, ind_part, axis=dim)
  37. return Tensor(val), ind
# This is very slow for large arrays, or indices
def _gather(array, indices):
  # Pure-tinygrad gather along the last axis: compare broadcast `indices`
  # against arange(array.shape[-1]) to build a one-hot selection mask, keep
  # matching elements via where(), and sum them out. Unlike npgather this
  # stays on-device and preserves the gradient.
  indices = indices.float().to(array.device)
  reshape_arg = [1]*array.ndim + [array.shape[-1]]
  return Tensor.where(
    indices.unsqueeze(indices.ndim).expand(*indices.shape, array.shape[-1]) == Tensor.arange(array.shape[-1]).reshape(*reshape_arg).expand(*indices.shape, array.shape[-1]),
    array, 0,
  ).sum(indices.ndim)
  46. # TODO: replace npgather with a faster gather using tinygrad only
  47. # NOTE: this blocks the gradient
  48. def npgather(array,indices):
  49. if isinstance(array, Tensor): array = array.numpy()
  50. if isinstance(indices, Tensor): indices = indices.numpy()
  51. if isinstance(indices, list): indices = np.asarray(indices)
  52. return Tensor(array[indices.astype(int)])
  53. def get_strides(shape):
  54. prod = [1]
  55. for idx in range(len(shape)-1, -1, -1): prod.append(prod[-1] * shape[idx])
  56. # something about ints is broken with gpu, cuda
  57. return Tensor(prod[::-1][1:], dtype=dtypes.int32).unsqueeze(0)
  58. # with keys as integer array for all axes
  59. def tensor_getitem(tensor, *keys):
  60. # something about ints is broken with gpu, cuda
  61. flat_keys = Tensor.stack(*[key.expand((sum(keys)).shape).reshape(-1) for key in keys], dim=1).cast(dtypes.int32)
  62. strides = get_strides(tensor.shape)
  63. idxs = (flat_keys * strides).sum(1)
  64. gatherer = npgather if USE_NP_GATHER else _gather
  65. return gatherer(tensor.reshape(-1), idxs).reshape(sum(keys).shape)
# for gather with indicies only on axis=0
def tensor_gather(tensor, indices):
  # Select rows (axis 0) of `tensor` given integer `indices`.
  if not isinstance(indices, Tensor):
    indices = Tensor(indices, requires_grad=False)
  if len(tensor.shape) > 2:
    # Flatten trailing dims to 2-D; restored at the end.
    rem_shape = list(tensor.shape)[1:]
    tensor = tensor.reshape(tensor.shape[0], -1)
  else:
    rem_shape = None
  if len(tensor.shape) > 1:
    # _gather selects along the last axis, so transpose rows there and
    # replicate the indices once per column.
    tensor = tensor.T
    repeat_arg = [1]*(tensor.ndim-1) + [tensor.shape[-2]]
    indices = indices.unsqueeze(indices.ndim).repeat(repeat_arg)
    ret = _gather(tensor, indices)
    if rem_shape:
      ret = ret.reshape([indices.shape[0]] + rem_shape)
  else:
    # 1-D input: gather directly.
    ret = _gather(tensor, indices)
  del indices
  return ret
  86. class LastLevelMaxPool:
  87. def __call__(self, x): return [Tensor.max_pool2d(x, 1, 2)]
# transpose
# Flip method codes accepted by BoxList.transpose().
FLIP_LEFT_RIGHT = 0
FLIP_TOP_BOTTOM = 1
  91. def permute_and_flatten(layer:Tensor, N, A, C, H, W):
  92. layer = layer.reshape(N, -1, C, H, W)
  93. layer = layer.permute(0, 3, 4, 1, 2)
  94. layer = layer.reshape(N, -1, C)
  95. return layer
class BoxList:
  """A set of bounding boxes for one image plus arbitrary per-box fields.

  bbox is an (N, 4) Tensor; mode is "xyxy" (corner coordinates) or "xywh"
  (corner plus width/height). size is (image_width, image_height).
  """
  def __init__(self, bbox, image_size, mode="xyxy"):
    if not isinstance(bbox, Tensor):
      bbox = Tensor(bbox)
    if bbox.ndim != 2:
      raise ValueError(
        "bbox should have 2 dimensions, got {}".format(bbox.ndim)
      )
    if bbox.shape[-1] != 4:
      raise ValueError(
        "last dimenion of bbox should have a "
        "size of 4, got {}".format(bbox.shape[-1])
      )
    if mode not in ("xyxy", "xywh"):
      raise ValueError("mode should be 'xyxy' or 'xywh'")
    self.bbox = bbox
    self.size = image_size # (image_width, image_height)
    self.mode = mode
    self.extra_fields = {}  # field name -> per-box data (scores, labels, ...)

  def __repr__(self):
    s = self.__class__.__name__ + "("
    s += "num_boxes={}, ".format(len(self))
    s += "image_width={}, ".format(self.size[0])
    s += "image_height={}, ".format(self.size[1])
    s += "mode={})".format(self.mode)
    return s

  def area(self):
    """Per-box area; xyxy uses the inclusive-pixel (+1) convention."""
    box = self.bbox
    if self.mode == "xyxy":
      TO_REMOVE = 1
      area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE)
    elif self.mode == "xywh":
      area = box[:, 2] * box[:, 3]
    return area

  def add_field(self, field, field_data):
    self.extra_fields[field] = field_data

  def get_field(self, field):
    return self.extra_fields[field]

  def has_field(self, field):
    return field in self.extra_fields

  def fields(self):
    return list(self.extra_fields.keys())

  def _copy_extra_fields(self, bbox):
    # Shallow-copy every extra field from another BoxList.
    for k, v in bbox.extra_fields.items():
      self.extra_fields[k] = v

  def convert(self, mode):
    """Return a BoxList in the requested mode; extra fields are carried over."""
    if mode == self.mode:
      return self
    xmin, ymin, xmax, ymax = self._split_into_xyxy()
    if mode == "xyxy":
      bbox = Tensor.cat(*(xmin, ymin, xmax, ymax), dim=-1)
      bbox = BoxList(bbox, self.size, mode=mode)
    else:
      TO_REMOVE = 1  # inclusive-pixel convention
      bbox = Tensor.cat(
        *(xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=-1
      )
      bbox = BoxList(bbox, self.size, mode=mode)
    bbox._copy_extra_fields(self)
    return bbox

  def _split_into_xyxy(self):
    """Return (xmin, ymin, xmax, ymax) as (N, 1) columns regardless of mode."""
    if self.mode == "xyxy":
      xmin, ymin, xmax, ymax = self.bbox.chunk(4, dim=-1)
      return xmin, ymin, xmax, ymax
    if self.mode == "xywh":
      TO_REMOVE = 1
      xmin, ymin, w, h = self.bbox.chunk(4, dim=-1)
      return (
        xmin,
        ymin,
        xmin + (w - TO_REMOVE).clamp(min=0),
        ymin + (h - TO_REMOVE).clamp(min=0),
      )

  def resize(self, size, *args, **kwargs):
    """Return a BoxList scaled to the new (width, height).

    NOTE(review): only non-Tensor extra fields get resized; Tensor fields
    (e.g. scores) are carried over unchanged.
    """
    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
    if ratios[0] == ratios[1]:
      # Uniform scaling: a single multiply works in either box mode.
      ratio = ratios[0]
      scaled_box = self.bbox * ratio
      bbox = BoxList(scaled_box, size, mode=self.mode)
      for k, v in self.extra_fields.items():
        if not isinstance(v, Tensor):
          v = v.resize(size, *args, **kwargs)
        bbox.add_field(k, v)
      return bbox
    # Non-uniform scaling: scale x and y separately in xyxy space.
    ratio_width, ratio_height = ratios
    xmin, ymin, xmax, ymax = self._split_into_xyxy()
    scaled_xmin = xmin * ratio_width
    scaled_xmax = xmax * ratio_width
    scaled_ymin = ymin * ratio_height
    scaled_ymax = ymax * ratio_height
    scaled_box = Tensor.cat(
      *(scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1
    )
    bbox = BoxList(scaled_box, size, mode="xyxy")
    for k, v in self.extra_fields.items():
      if not isinstance(v, Tensor):
        v = v.resize(size, *args, **kwargs)
      bbox.add_field(k, v)
    return bbox.convert(self.mode)

  def transpose(self, method):
    """Flip all boxes; method is FLIP_LEFT_RIGHT or FLIP_TOP_BOTTOM.

    As in resize(), only non-Tensor extra fields are transposed.
    """
    image_width, image_height = self.size
    xmin, ymin, xmax, ymax = self._split_into_xyxy()
    if method == FLIP_LEFT_RIGHT:
      TO_REMOVE = 1  # inclusive-pixel convention
      transposed_xmin = image_width - xmax - TO_REMOVE
      transposed_xmax = image_width - xmin - TO_REMOVE
      transposed_ymin = ymin
      transposed_ymax = ymax
    elif method == FLIP_TOP_BOTTOM:
      transposed_xmin = xmin
      transposed_xmax = xmax
      transposed_ymin = image_height - ymax
      transposed_ymax = image_height - ymin
    transposed_boxes = Tensor.cat(
      *(transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1
    )
    bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
    for k, v in self.extra_fields.items():
      if not isinstance(v, Tensor):
        v = v.transpose(method)
      bbox.add_field(k, v)
    return bbox.convert(self.mode)

  def clip_to_image(self, remove_empty=True):
    """Clamp boxes to the image bounds in place; optionally drop degenerate boxes."""
    TO_REMOVE = 1
    bb1 = self.bbox.clip(min_=0, max_=self.size[0] - TO_REMOVE)[:, 0]
    bb2 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 1]
    bb3 = self.bbox.clip(min_=0, max_=self.size[0] - TO_REMOVE)[:, 2]
    bb4 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 3]
    self.bbox = Tensor.stack(bb1, bb2, bb3, bb4, dim=1)
    if remove_empty:
      box = self.bbox
      keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
      return self[keep]
    return self

  def __getitem__(self, item):
    """Index with a list/array of indices or a boolean mask; returns a BoxList."""
    if isinstance(item, list):
      if len(item) == 0:
        return []
      # An all-True boolean mask selects everything: return self unchanged.
      if sum(item) == len(item) and isinstance(item[0], bool):
        return self
    bbox = BoxList(tensor_gather(self.bbox, item), self.size, self.mode)
    for k, v in self.extra_fields.items():
      bbox.add_field(k, tensor_gather(v, item))
    return bbox

  def __len__(self):
    return self.bbox.shape[0]
  242. def cat_boxlist(bboxes):
  243. size = bboxes[0].size
  244. mode = bboxes[0].mode
  245. fields = set(bboxes[0].fields())
  246. cat_box_list = [bbox.bbox for bbox in bboxes if bbox.bbox.shape[0] > 0]
  247. if len(cat_box_list) > 0:
  248. cat_boxes = BoxList(Tensor.cat(*cat_box_list, dim=0), size, mode)
  249. else:
  250. cat_boxes = BoxList(bboxes[0].bbox, size, mode)
  251. for field in fields:
  252. cat_field_list = [bbox.get_field(field) for bbox in bboxes if bbox.get_field(field).shape[0] > 0]
  253. if len(cat_box_list) > 0:
  254. data = Tensor.cat(*cat_field_list, dim=0)
  255. else:
  256. data = bboxes[0].get_field(field)
  257. cat_boxes.add_field(field, data)
  258. return cat_boxes
class FPN:
  """Feature Pyramid Network.

  Consumes the backbone's multi-scale feature maps (finest first) and emits
  uniform-channel pyramid levels via 1x1 lateral convs merged through a
  top-down pathway, plus one extra max-pooled level appended at the end.
  """
  def __init__(self, in_channels_list, out_channels):
    self.inner_blocks, self.layer_blocks = [], []
    for in_channels in in_channels_list:
      # 1x1 lateral conv unifies channels; 3x3 conv smooths the merged map.
      self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
      self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
    self.top_block = LastLevelMaxPool()
  def __call__(self, x: Tensor):
    # Start at the coarsest level and work top-down toward the finest.
    last_inner = self.inner_blocks[-1](x[-1])
    results = []
    results.append(self.layer_blocks[-1](last_inner))
    for feature, inner_block, layer_block in zip(
      x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
    ):
      if not inner_block:
        continue
      # Upsample the coarser map and merge with this level's lateral conv.
      inner_top_down = nearest_interpolate(last_inner, scale_factor=2)
      inner_lateral = inner_block(feature)
      last_inner = inner_lateral + inner_top_down
      layer_result = layer_block(last_inner)
      results.insert(0, layer_result)
    # Append the extra coarsest level produced by the max pool.
    last_results = self.top_block(results[-1])
    results.extend(last_results)
    return tuple(results)
  283. class ResNetFPN:
  284. def __init__(self, resnet, out_channels=256):
  285. self.out_channels = out_channels
  286. self.body = resnet
  287. in_channels_stage2 = 256
  288. in_channels_list = [
  289. in_channels_stage2,
  290. in_channels_stage2 * 2,
  291. in_channels_stage2 * 4,
  292. in_channels_stage2 * 8,
  293. ]
  294. self.fpn = FPN(in_channels_list, out_channels)
  295. def __call__(self, x):
  296. x = self.body(x)
  297. return self.fpn(x)
class AnchorGenerator:
  """Generates per-level anchor boxes for FPN feature maps and tags which
  anchors lie inside the image (the "visibility" field)."""
  def __init__(
    self,
    sizes=(32, 64, 128, 256, 512),
    aspect_ratios=(0.5, 1.0, 2.0),
    anchor_strides=(4, 8, 16, 32, 64),
    straddle_thresh=0,
  ):
    if len(anchor_strides) == 1:
      # Single feature map: every size shares the one stride.
      anchor_stride = anchor_strides[0]
      cell_anchors = [
        generate_anchors(anchor_stride, sizes, aspect_ratios)
      ]
    else:
      if len(anchor_strides) != len(sizes):
        raise RuntimeError("FPN should have #anchor_strides == #sizes")
      cell_anchors = [
        generate_anchors(
          anchor_stride,
          size if isinstance(size, (tuple, list)) else (size,),
          aspect_ratios
        )
        for anchor_stride, size in zip(anchor_strides, sizes)
      ]
    self.strides = anchor_strides
    self.cell_anchors = cell_anchors  # per-level base anchors centered near (0, 0)
    self.straddle_thresh = straddle_thresh
  def num_anchors_per_location(self):
    # Base anchors (sizes x aspect ratios) per spatial position, per level.
    return [cell_anchors.shape[0] for cell_anchors in self.cell_anchors]
  def grid_anchors(self, grid_sizes):
    """Shift each level's base anchors across the full feature-map grid."""
    anchors = []
    for size, stride, base_anchors in zip(
      grid_sizes, self.strides, self.cell_anchors
    ):
      grid_height, grid_width = size
      device = base_anchors.device
      shifts_x = Tensor.arange(
        start=0, stop=grid_width * stride, step=stride, dtype=dtypes.float32, device=device
      )
      shifts_y = Tensor.arange(
        start=0, stop=grid_height * stride, step=stride, dtype=dtypes.float32, device=device
      )
      shift_y, shift_x = meshgrid(shifts_y, shifts_x)
      shift_x = shift_x.reshape(-1)
      shift_y = shift_y.reshape(-1)
      # (x, y, x, y) so the same shift moves both box corners.
      shifts = Tensor.stack(shift_x, shift_y, shift_x, shift_y, dim=1)
      # Broadcast: every grid shift applied to every base anchor.
      anchors.append(
        (shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(-1, 4)
      )
    return anchors
  def add_visibility_to(self, boxlist):
    """Attach a "visibility" field: 1 where the anchor lies within
    straddle_thresh of the image; all-ones when straddle_thresh < 0."""
    image_width, image_height = boxlist.size
    anchors = boxlist.bbox
    if self.straddle_thresh >= 0:
      inds_inside = (
        (anchors[:, 0] >= -self.straddle_thresh)
        * (anchors[:, 1] >= -self.straddle_thresh)
        * (anchors[:, 2] < image_width + self.straddle_thresh)
        * (anchors[:, 3] < image_height + self.straddle_thresh)
      )
    else:
      device = anchors.device
      inds_inside = Tensor.ones(anchors.shape[0], dtype=dtypes.uint8, device=device)
    boxlist.add_field("visibility", inds_inside)
  def __call__(self, image_list, feature_maps):
    """Return anchors as [image][level] BoxLists for the given feature maps."""
    grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
    anchors_over_all_feature_maps = self.grid_anchors(grid_sizes)
    anchors = []
    for (image_height, image_width) in image_list.image_sizes:
      anchors_in_image = []
      for anchors_per_feature_map in anchors_over_all_feature_maps:
        boxlist = BoxList(
          anchors_per_feature_map, (image_width, image_height), mode="xyxy"
        )
        self.add_visibility_to(boxlist)
        anchors_in_image.append(boxlist)
      anchors.append(anchors_in_image)
    return anchors
  376. def generate_anchors(
  377. stride=16, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)
  378. ):
  379. return _generate_anchors(stride, Tensor(list(sizes)) / stride, Tensor(list(aspect_ratios)))
  380. def _generate_anchors(base_size, scales, aspect_ratios):
  381. anchor = Tensor([1, 1, base_size, base_size]) - 1
  382. anchors = _ratio_enum(anchor, aspect_ratios)
  383. anchors = Tensor.cat(
  384. *[_scale_enum(anchors[i, :], scales).reshape(-1, 4) for i in range(anchors.shape[0])]
  385. )
  386. return anchors
  387. def _whctrs(anchor):
  388. w = anchor[2] - anchor[0] + 1
  389. h = anchor[3] - anchor[1] + 1
  390. x_ctr = anchor[0] + 0.5 * (w - 1)
  391. y_ctr = anchor[1] + 0.5 * (h - 1)
  392. return w, h, x_ctr, y_ctr
  393. def _mkanchors(ws, hs, x_ctr, y_ctr):
  394. ws = ws[:, None]
  395. hs = hs[:, None]
  396. anchors = Tensor.cat(*(
  397. x_ctr - 0.5 * (ws - 1),
  398. y_ctr - 0.5 * (hs - 1),
  399. x_ctr + 0.5 * (ws - 1),
  400. y_ctr + 0.5 * (hs - 1),
  401. ), dim=1)
  402. return anchors
  403. def _ratio_enum(anchor, ratios):
  404. w, h, x_ctr, y_ctr = _whctrs(anchor)
  405. size = w * h
  406. size_ratios = size / ratios
  407. ws = rint(Tensor.sqrt(size_ratios))
  408. hs = rint(ws * ratios)
  409. anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
  410. return anchors
  411. def _scale_enum(anchor, scales):
  412. w, h, x_ctr, y_ctr = _whctrs(anchor)
  413. ws = w * scales
  414. hs = h * scales
  415. anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
  416. return anchors
  417. class RPNHead:
  418. def __init__(self, in_channels, num_anchors):
  419. self.conv = nn.Conv2d(in_channels, 256, kernel_size=3, padding=1)
  420. self.cls_logits = nn.Conv2d(256, num_anchors, kernel_size=1)
  421. self.bbox_pred = nn.Conv2d(256, num_anchors * 4, kernel_size=1)
  422. def __call__(self, x):
  423. logits = []
  424. bbox_reg = []
  425. for feature in x:
  426. t = Tensor.relu(self.conv(feature))
  427. logits.append(self.cls_logits(t))
  428. bbox_reg.append(self.bbox_pred(t))
  429. return logits, bbox_reg
class BoxCoder(object):
  """Encodes boxes as (dx, dy, dw, dh) regression deltas and decodes them
  back, following the R-CNN parameterization; `weights` scales each delta."""
  def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
    self.weights = weights  # (wx, wy, ww, wh)
    self.bbox_xform_clip = bbox_xform_clip  # cap on dw/dh before exp()
  def encode(self, reference_boxes, proposals):
    """Compute the deltas that map `proposals` onto `reference_boxes`."""
    TO_REMOVE = 1 # TODO remove
    ex_widths = proposals[:, 2] - proposals[:, 0] + TO_REMOVE
    ex_heights = proposals[:, 3] - proposals[:, 1] + TO_REMOVE
    ex_ctr_x = proposals[:, 0] + 0.5 * ex_widths
    ex_ctr_y = proposals[:, 1] + 0.5 * ex_heights
    gt_widths = reference_boxes[:, 2] - reference_boxes[:, 0] + TO_REMOVE
    gt_heights = reference_boxes[:, 3] - reference_boxes[:, 1] + TO_REMOVE
    gt_ctr_x = reference_boxes[:, 0] + 0.5 * gt_widths
    gt_ctr_y = reference_boxes[:, 1] + 0.5 * gt_heights
    wx, wy, ww, wh = self.weights
    targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
    targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
    targets_dw = ww * Tensor.log(gt_widths / ex_widths)
    targets_dh = wh * Tensor.log(gt_heights / ex_heights)
    targets = Tensor.stack(targets_dx, targets_dy, targets_dw, targets_dh, dim=1)
    return targets
  def decode(self, rel_codes, boxes):
    """Apply deltas `rel_codes` (N, 4*k) to `boxes` (N, 4); returns xyxy boxes."""
    boxes = boxes.cast(rel_codes.dtype)
    rel_codes = rel_codes
    TO_REMOVE = 1 # TODO remove
    widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE
    heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights
    wx, wy, ww, wh = self.weights
    # Every 4th column: deltas may be packed per-class (k sets of 4).
    dx = rel_codes[:, 0::4] / wx
    dy = rel_codes[:, 1::4] / wy
    dw = rel_codes[:, 2::4] / ww
    dh = rel_codes[:, 3::4] / wh
    # Prevent sending too large values into Tensor.exp()
    # NOTE(review): min_=dw.min() makes the lower bound a no-op; only the
    # upper bound (bbox_xform_clip) actually clips.
    dw = dw.clip(min_=dw.min(), max_=self.bbox_xform_clip)
    dh = dh.clip(min_=dh.min(), max_=self.bbox_xform_clip)
    pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
    pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
    pred_w = dw.exp() * widths[:, None]
    pred_h = dh.exp() * heights[:, None]
    x = pred_ctr_x - 0.5 * pred_w
    y = pred_ctr_y - 0.5 * pred_h
    w = pred_ctr_x + 0.5 * pred_w - 1
    h = pred_ctr_y + 0.5 * pred_h - 1
    pred_boxes = Tensor.stack(x, y, w, h).permute(1,2,0).reshape(rel_codes.shape[0], rel_codes.shape[1])
    return pred_boxes
  477. def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"):
  478. if nms_thresh <= 0:
  479. return boxlist
  480. mode = boxlist.mode
  481. boxlist = boxlist.convert("xyxy")
  482. boxes = boxlist.bbox
  483. score = boxlist.get_field(score_field)
  484. keep = _box_nms(boxes.numpy(), score.numpy(), nms_thresh)
  485. if max_proposals > 0:
  486. keep = keep[:max_proposals]
  487. boxlist = boxlist[keep]
  488. return boxlist.convert(mode)
  489. def remove_small_boxes(boxlist, min_size):
  490. xywh_boxes = boxlist.convert("xywh").bbox
  491. _, _, ws, hs = xywh_boxes.chunk(4, dim=1)
  492. keep = ((
  493. (ws >= min_size) * (hs >= min_size)
  494. ) > 0).reshape(-1)
  495. if keep.sum().numpy() == len(boxlist):
  496. return boxlist
  497. else:
  498. keep = keep.numpy().nonzero()[0]
  499. return boxlist[keep]
class RPNPostProcessor:
  """Turns raw RPN outputs into proposal BoxLists (inference only).

  Per FPN level: sigmoid objectness, keep the top pre_nms_top_n anchors,
  decode boxes, clip/filter, NMS; then merge levels per image and keep the
  top fpn_post_nms_top_n overall.
  """
  # Not used in Loss calculation
  def __init__(
    self,
    pre_nms_top_n,
    post_nms_top_n,
    nms_thresh,
    min_size,
    box_coder=None,
    fpn_post_nms_top_n=None,
  ):
    self.pre_nms_top_n = pre_nms_top_n
    self.post_nms_top_n = post_nms_top_n
    self.nms_thresh = nms_thresh
    self.min_size = min_size
    if box_coder is None:
      box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
    self.box_coder = box_coder
    if fpn_post_nms_top_n is None:
      fpn_post_nms_top_n = post_nms_top_n
    self.fpn_post_nms_top_n = fpn_post_nms_top_n
  def forward_for_single_feature_map(self, anchors, objectness, box_regression):
    """Select and decode proposals for one FPN level across the batch."""
    device = objectness.device
    N, A, H, W = objectness.shape
    objectness = permute_and_flatten(objectness, N, A, 1, H, W).reshape(N, -1)
    objectness = objectness.sigmoid()
    box_regression = permute_and_flatten(box_regression, N, A, 4, H, W)
    num_anchors = A * H * W
    pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)
    objectness, topk_idx = topk(objectness, pre_nms_top_n, dim=1, sorted=False)
    concat_anchors = Tensor.cat(*[a.bbox for a in anchors], dim=0).reshape(N, -1, 4)
    image_shapes = [box.size for box in anchors]
    box_regression_list = []
    concat_anchors_list = []
    # Gather each image's top-k regressions and matching anchors.
    for batch_idx in range(N):
      box_regression_list.append(tensor_gather(box_regression[batch_idx], topk_idx[batch_idx]))
      concat_anchors_list.append(tensor_gather(concat_anchors[batch_idx], topk_idx[batch_idx]))
    box_regression = Tensor.stack(*box_regression_list)
    concat_anchors = Tensor.stack(*concat_anchors_list)
    proposals = self.box_coder.decode(
      box_regression.reshape(-1, 4), concat_anchors.reshape(-1, 4)
    )
    proposals = proposals.reshape(N, -1, 4)
    result = []
    for proposal, score, im_shape in zip(proposals, objectness, image_shapes):
      boxlist = BoxList(proposal, im_shape, mode="xyxy")
      boxlist.add_field("objectness", score)
      boxlist = boxlist.clip_to_image(remove_empty=False)
      boxlist = remove_small_boxes(boxlist, self.min_size)
      boxlist = boxlist_nms(
        boxlist,
        self.nms_thresh,
        max_proposals=self.post_nms_top_n,
        score_field="objectness",
      )
      result.append(boxlist)
    return result
  def __call__(self, anchors, objectness, box_regression):
    sampled_boxes = []
    num_levels = len(objectness)
    # anchors arrive as [image][level]; regroup to [level][image].
    anchors = list(zip(*anchors))
    for a, o, b in zip(anchors, objectness, box_regression):
      sampled_boxes.append(self.forward_for_single_feature_map(a, o, b))
    # Regroup back to per-image lists and merge the levels.
    boxlists = list(zip(*sampled_boxes))
    boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]
    if num_levels > 1:
      boxlists = self.select_over_all_levels(boxlists)
    return boxlists
  def select_over_all_levels(self, boxlists):
    """Keep the overall top fpn_post_nms_top_n proposals per image."""
    num_images = len(boxlists)
    for i in range(num_images):
      objectness = boxlists[i].get_field("objectness")
      post_nms_top_n = min(self.fpn_post_nms_top_n, objectness.shape[0])
      _, inds_sorted = topk(objectness,
        post_nms_top_n, dim=0, sorted=False
      )
      boxlists[i] = boxlists[i][inds_sorted]
    return boxlists
class RPN:
  """Region Proposal Network: head + anchor generation + proposal selection.

  Inference-only: targets are accepted but ignored and the loss dict is empty.
  """
  def __init__(self, in_channels):
    self.anchor_generator = AnchorGenerator()
    # NOTE(review): the in_channels argument is ignored — the head is always
    # built for 256-channel FPN features.
    in_channels = 256
    head = RPNHead(
      in_channels, self.anchor_generator.num_anchors_per_location()[0]
    )
    rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
    box_selector_test = RPNPostProcessor(
      pre_nms_top_n=1000,
      post_nms_top_n=1000,
      nms_thresh=0.7,
      min_size=0,
      box_coder=rpn_box_coder,
      fpn_post_nms_top_n=1000
    )
    self.head = head
    self.box_selector_test = box_selector_test
  def __call__(self, images, features, targets=None):
    """Return (proposal BoxLists per image, empty loss dict)."""
    objectness, rpn_box_regression = self.head(features)
    anchors = self.anchor_generator(images, features)
    boxes = self.box_selector_test(anchors, objectness, rpn_box_regression)
    return boxes, {}
  601. def make_conv3x3(
  602. in_channels,
  603. out_channels,
  604. dilation=1,
  605. stride=1,
  606. use_gn=False,
  607. ):
  608. conv = nn.Conv2d(
  609. in_channels,
  610. out_channels,
  611. kernel_size=3,
  612. stride=stride,
  613. padding=dilation,
  614. dilation=dilation,
  615. bias=False if use_gn else True
  616. )
  617. return conv
class MaskRCNNFPNFeatureExtractor:
  """Mask-head feature extractor: ROI-pool FPN features to 14x14, then four
  3x3 conv + ReLU layers."""
  def __init__(self):
    resolution = 14
    # One scale per FPN level (strides 4, 8, 16, 32).
    scales = (0.25, 0.125, 0.0625, 0.03125)
    sampling_ratio = 2
    pooler = Pooler(
      output_size=(resolution, resolution),
      scales=scales,
      sampling_ratio=sampling_ratio,
    )
    input_size = 256
    self.pooler = pooler
    use_gn = False
    layers = (256, 256, 256, 256)
    dilation = 1
    self.mask_fcn1 = make_conv3x3(input_size, layers[0], dilation=dilation, stride=1, use_gn=use_gn)
    self.mask_fcn2 = make_conv3x3(layers[0], layers[1], dilation=dilation, stride=1, use_gn=use_gn)
    self.mask_fcn3 = make_conv3x3(layers[1], layers[2], dilation=dilation, stride=1, use_gn=use_gn)
    self.mask_fcn4 = make_conv3x3(layers[2], layers[3], dilation=dilation, stride=1, use_gn=use_gn)
    self.blocks = [self.mask_fcn1, self.mask_fcn2, self.mask_fcn3, self.mask_fcn4]
  def __call__(self, x, proposals):
    x = self.pooler(x, proposals)
    # x may be None (presumably when there are no proposals — guard below);
    # in that case the convs are skipped and None is returned.
    for layer in self.blocks:
      if x is not None:
        x = Tensor.relu(layer(x))
    return x
  644. class MaskRCNNC4Predictor:
  645. def __init__(self):
  646. num_classes = 81
  647. dim_reduced = 256
  648. num_inputs = dim_reduced
  649. self.conv5_mask = nn.ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0)
  650. self.mask_fcn_logits = nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)
  651. def __call__(self, x):
  652. x = Tensor.relu(self.conv5_mask(x))
  653. return self.mask_fcn_logits(x)
  654. class FPN2MLPFeatureExtractor:
  655. def __init__(self, cfg):
  656. resolution = 7
  657. scales = (0.25, 0.125, 0.0625, 0.03125)
  658. sampling_ratio = 2
  659. pooler = Pooler(
  660. output_size=(resolution, resolution),
  661. scales=scales,
  662. sampling_ratio=sampling_ratio,
  663. )
  664. input_size = 256 * resolution ** 2
  665. representation_size = 1024
  666. self.pooler = pooler
  667. self.fc6 = nn.Linear(input_size, representation_size)
  668. self.fc7 = nn.Linear(representation_size, representation_size)
  669. def __call__(self, x, proposals):
  670. x = self.pooler(x, proposals)
  671. x = x.reshape(x.shape[0], -1)
  672. x = Tensor.relu(self.fc6(x))
  673. x = Tensor.relu(self.fc7(x))
  674. return x
def _bilinear_interpolate(
  input, # [N, C, H, W]
  roi_batch_ind, # [K]
  y, # [K, PH, IY]
  x, # [K, PW, IX]
  ymask, # [K, IY]
  xmask, # [K, IX]
):
  """Bilinearly sample `input` at fractional (y, x) locations for ROI align.

  Returns samples shaped [K, C, PH, PW, IY, IX]; ymask/xmask (when given)
  redirect out-of-ROI sampling points to coordinate 0.
  """
  _, channels, height, width = input.shape
  # Clamp sample coordinates to the valid image range.
  y = y.clip(min_=0.0, max_=float(height-1))
  x = x.clip(min_=0.0, max_=float(width-1))
  # Tensor.where doesnt work well with int32 data so cast to float32
  y_low = y.cast(dtypes.int32).contiguous().float().contiguous()
  x_low = x.cast(dtypes.int32).contiguous().float().contiguous()
  # Neighboring integer coordinates, clamped at the border.
  y_high = Tensor.where(y_low >= height - 1, float(height - 1), y_low + 1)
  y_low = Tensor.where(y_low >= height - 1, float(height - 1), y_low)
  x_high = Tensor.where(x_low >= width - 1, float(width - 1), x_low + 1)
  x_low = Tensor.where(x_low >= width - 1, float(width - 1), x_low)
  # Fractional offsets (l*) and complements (h*) — the bilinear weights.
  ly = y - y_low
  lx = x - x_low
  hy = 1.0 - ly
  hx = 1.0 - lx
  def masked_index(
    y, # [K, PH, IY]
    x, # [K, PW, IX]
  ):
    # Fetch input values at integer (y, x); masked-out points read index 0.
    if ymask is not None:
      assert xmask is not None
      y = Tensor.where(ymask[:, None, :], y, 0)
      x = Tensor.where(xmask[:, None, :], x, 0)
    # Broadcast keys so every (roi, channel, ph, pw, iy, ix) combination is indexed.
    key1 = roi_batch_ind[:, None, None, None, None, None]
    key2 = Tensor.arange(channels, device=input.device)[None, :, None, None, None, None]
    key3 = y[:, None, :, None, :, None]
    key4 = x[:, None, None, :, None, :]
    return tensor_getitem(input,key1,key2,key3,key4) # [K, C, PH, PW, IY, IX]
  # The four corner samples surrounding each fractional location.
  v1 = masked_index(y_low, x_low)
  v2 = masked_index(y_low, x_high)
  v3 = masked_index(y_high, x_low)
  v4 = masked_index(y_high, x_high)
  # all ws preemptively [K, C, PH, PW, IY, IX]
  def outer_prod(y, x):
    return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]
  w1 = outer_prod(hy, hx)
  w2 = outer_prod(hy, lx)
  w3 = outer_prod(ly, hx)
  w4 = outer_prod(ly, lx)
  # Weighted sum of the four corners = bilinear interpolation.
  val = w1*v1 + w2*v2 + w3*v3 + w4*v4
  return val
#https://pytorch.org/vision/main/_modules/torchvision/ops/roi_align.html#roi_align
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
  """ROIAlign: average-pool bilinear samples of `input` inside each ROI.

  input: [N, C, H, W]; rois: rows of (batch_idx, x1, y1, x2, y2) in image
  coordinates (scaled onto the feature map by `spatial_scale`).
  Returns [K, C, pooled_height, pooled_width] in input's original dtype.
  """
  orig_dtype = input.dtype
  _, _, height, width = input.shape
  ph = Tensor.arange(pooled_height, device=input.device)
  pw = Tensor.arange(pooled_width, device=input.device)
  roi_batch_ind = rois[:, 0].cast(dtypes.int32).contiguous()
  # aligned=True shifts boxes by half a pixel so sample centers line up.
  offset = 0.5 if aligned else 0.0
  roi_start_w = rois[:, 1] * spatial_scale - offset
  roi_start_h = rois[:, 2] * spatial_scale - offset
  roi_end_w = rois[:, 3] * spatial_scale - offset
  roi_end_h = rois[:, 4] * spatial_scale - offset
  roi_width = roi_end_w - roi_start_w
  roi_height = roi_end_h - roi_start_h
  if not aligned:
    # Legacy (non-aligned) behavior: ROIs are at least 1x1 feature-map pixel.
    roi_width = roi_width.maximum(1.0)
    roi_height = roi_height.maximum(1.0)
  bin_size_h = roi_height / pooled_height
  bin_size_w = roi_width / pooled_width
  exact_sampling = sampling_ratio > 0
  # Exact sampling: a fixed sampling_ratio x sampling_ratio grid per bin.
  # Adaptive sampling: grid size depends on ROI size and is masked per-ROI below.
  roi_bin_grid_h = sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil()
  roi_bin_grid_w = sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil()
  if exact_sampling:
    count = max(roi_bin_grid_h * roi_bin_grid_w, 1)
    iy = Tensor.arange(roi_bin_grid_h, device=input.device)
    ix = Tensor.arange(roi_bin_grid_w, device=input.device)
    ymask = None
    xmask = None
  else:
    count = (roi_bin_grid_h * roi_bin_grid_w).maximum(1)
    # Allocate the largest possible grid, mark unused samples invalid per ROI.
    iy = Tensor.arange(height, device=input.device)
    ix = Tensor.arange(width, device=input.device)
    ymask = iy[None, :] < roi_bin_grid_h[:, None]
    xmask = ix[None, :] < roi_bin_grid_w[:, None]
  def from_K(t):
    # Broadcast a per-ROI [K] tensor against the pooled/sample axes.
    return t[:, None, None]
  # Sample coords = ROI start + pooled-bin offset + centered sub-sample offset.
  y = (
    from_K(roi_start_h)
    + ph[None, :, None] * from_K(bin_size_h)
    + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h)
  )
  x = (
    from_K(roi_start_w)
    + pw[None, :, None] * from_K(bin_size_w)
    + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w)
  )
  val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)
  if not exact_sampling:
    # Zero out contributions from masked (out-of-grid) samples before the sum.
    val = ymask[:, None, None, None, :, None].where(val, 0)
    val = xmask[:, None, None, None, None, :].where(val, 0)
  output = val.sum((-1, -2))
  # count is a per-ROI Tensor in the adaptive path, a Python scalar otherwise.
  if isinstance(count, Tensor):
    output /= count[:, None, None, None]
  else:
    output /= count
  output = output.cast(orig_dtype)
  return output
  780. class ROIAlign:
  781. def __init__(self, output_size, spatial_scale, sampling_ratio):
  782. self.output_size = output_size
  783. self.spatial_scale = spatial_scale
  784. self.sampling_ratio = sampling_ratio
  785. def __call__(self, input, rois):
  786. output = _roi_align(
  787. input, rois, self.spatial_scale, self.output_size[0], self.output_size[1], self.sampling_ratio, aligned=False
  788. )
  789. return output
  790. class LevelMapper:
  791. def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
  792. self.k_min = k_min
  793. self.k_max = k_max
  794. self.s0 = canonical_scale
  795. self.lvl0 = canonical_level
  796. self.eps = eps
  797. def __call__(self, boxlists):
  798. s = Tensor.sqrt(Tensor.cat(*[boxlist.area() for boxlist in boxlists]))
  799. target_lvls = (self.lvl0 + Tensor.log2(s / self.s0 + self.eps)).floor()
  800. target_lvls = target_lvls.clip(min_=self.k_min, max_=self.k_max)
  801. return target_lvls - self.k_min
class Pooler:
  """Multi-level ROIAlign pooler for FPN.

  Builds one ROIAlign per feature-map scale, assigns each proposal to a level
  via LevelMapper, pools each proposal from its level, then restores the
  proposals' original ordering in the output.
  """
  def __init__(self, output_size, scales, sampling_ratio):
    self.output_size = output_size
    self.scales = scales
    self.sampling_ratio = sampling_ratio
    poolers = []
    for scale in scales:
      poolers.append(
        ROIAlign(
          output_size, spatial_scale=scale, sampling_ratio=sampling_ratio
        )
      )
    self.poolers = poolers
    self.output_size = output_size
    # Scales are powers of two, so -log2(scale) recovers the FPN level index.
    lvl_min = -math.log2(scales[0])
    lvl_max = -math.log2(scales[-1])
    self.map_levels = LevelMapper(lvl_min, lvl_max)
  def convert_to_roi_format(self, boxes):
    """Concatenate BoxLists into ROI rows of (image_idx, x1, y1, x2, y2).

    Implicitly returns None when there are no boxes at all.
    """
    concat_boxes = Tensor.cat(*[b.bbox for b in boxes], dim=0)
    device, dtype = concat_boxes.device, concat_boxes.dtype
    # Column 0 of each ROI is the index of the image the box came from.
    ids = Tensor.cat(
      *[
        Tensor.full((len(b), 1), i, dtype=dtype, device=device)
        for i, b in enumerate(boxes)
      ],
      dim=0,
    )
    if concat_boxes.shape[0] != 0:
      rois = Tensor.cat(*[ids, concat_boxes], dim=1)
      return rois
  def __call__(self, x, boxes):
    """Pool features for `boxes` from per-level feature maps `x`.

    Implicitly returns None when convert_to_roi_format produced no ROIs.
    """
    num_levels = len(self.poolers)
    rois = self.convert_to_roi_format(boxes)
    if rois is not None:
      if num_levels == 1:
        return self.poolers[0](x[0], rois)
      levels = self.map_levels(boxes)
      results = []
      all_idxs = []
      for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
        # this is fine because no grad will flow through index
        idx_in_level = (levels.numpy() == level).nonzero()[0]
        if len(idx_in_level) > 0:
          rois_per_level = tensor_gather(rois, idx_in_level)
          pooler_output = pooler(per_level_feature, rois_per_level)
          all_idxs.extend(idx_in_level)
          results.append(pooler_output)
      # Outputs were produced level-by-level; gather with the inverse of the
      # level-ordered index permutation to restore the original ROI order.
      return tensor_gather(Tensor.cat(*results), [x[0] for x in sorted({i:idx for i, idx in enumerate(all_idxs)}.items(), key=lambda x: x[1])])
  850. class FPNPredictor:
  851. def __init__(self):
  852. num_classes = 81
  853. representation_size = 1024
  854. self.cls_score = nn.Linear(representation_size, num_classes)
  855. num_bbox_reg_classes = num_classes
  856. self.bbox_pred = nn.Linear(representation_size, num_bbox_reg_classes * 4)
  857. def __call__(self, x):
  858. scores = self.cls_score(x)
  859. bbox_deltas = self.bbox_pred(x)
  860. return scores, bbox_deltas
class PostProcessor:
  # Not used in training
  """Converts box-head output (class_logits, box_regression) into detections.

  Applies softmax, decodes regression deltas against the proposals, then per
  class: score threshold, NMS, and finally a per-image detection cap.
  """
  def __init__(
    self,
    score_thresh=0.05,       # minimum class probability to keep a box
    nms=0.5,                 # NMS threshold passed to boxlist_nms
    detections_per_img=100,  # hard cap on detections kept per image
    box_coder=None,          # defaults to BoxCoder with weights (10, 10, 5, 5)
    cls_agnostic_bbox_reg=False
  ):
    self.score_thresh = score_thresh
    self.nms = nms
    self.detections_per_img = detections_per_img
    if box_coder is None:
      box_coder = BoxCoder(weights=(10., 10., 5., 5.))
    self.box_coder = box_coder
    self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg
  def __call__(self, x, boxes):
    """Decode and filter predictions; returns one BoxList per input image."""
    class_logits, box_regression = x
    class_prob = Tensor.softmax(class_logits, -1)
    image_shapes = [box.size for box in boxes]
    boxes_per_image = [len(box) for box in boxes]
    concat_boxes = Tensor.cat(*[a.bbox for a in boxes], dim=0)
    if self.cls_agnostic_bbox_reg:
      # Class-agnostic: only the last 4 regression channels are used.
      box_regression = box_regression[:, -4:]
    proposals = self.box_coder.decode(
      box_regression.reshape(sum(boxes_per_image), -1), concat_boxes
    )
    if self.cls_agnostic_bbox_reg:
      # Replicate the single decoded box across all classes.
      proposals = proposals.repeat([1, class_prob.shape[1]])
    num_classes = class_prob.shape[1]
    # NOTE(review): unsqueeze(0) + zip yields a single iteration; this appears
    # to assume all proposals belong to one image — confirm against callers.
    proposals = proposals.unsqueeze(0)
    class_prob = class_prob.unsqueeze(0)
    results = []
    for prob, boxes_per_img, image_shape in zip(
      class_prob, proposals, image_shapes
    ):
      boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape)
      boxlist = boxlist.clip_to_image(remove_empty=False)
      boxlist = self.filter_results(boxlist, num_classes)
      results.append(boxlist)
    return results
  def prepare_boxlist(self, boxes, scores, image_shape):
    """Flatten per-class boxes/scores into a BoxList carrying a 'scores' field."""
    boxes = boxes.reshape(-1, 4)
    scores = scores.reshape(-1)
    boxlist = BoxList(boxes, image_shape, mode="xyxy")
    boxlist.add_field("scores", scores)
    return boxlist
  def filter_results(self, boxlist, num_classes):
    """Apply score threshold, per-class NMS, and the per-image detection cap."""
    boxes = boxlist.bbox.reshape(-1, num_classes * 4)
    scores = boxlist.get_field("scores").reshape(-1, num_classes)
    device = scores.device
    result = []
    scores = scores.numpy()
    boxes = boxes.numpy()
    inds_all = scores > self.score_thresh
    # Class index 0 (presumably background) is skipped.
    for j in range(1, num_classes):
      inds = inds_all[:, j].nonzero()[0]
      # This needs to be done in numpy because it can create empty arrays
      scores_j = scores[inds, j]
      boxes_j = boxes[inds, j * 4: (j + 1) * 4]
      boxes_j = Tensor(boxes_j)
      scores_j = Tensor(scores_j)
      boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
      boxlist_for_class.add_field("scores", scores_j)
      if len(boxlist_for_class):
        boxlist_for_class = boxlist_nms(
          boxlist_for_class, self.nms
        )
      num_labels = len(boxlist_for_class)
      boxlist_for_class.add_field(
        "labels", Tensor.full((num_labels,), j, device=device)
      )
      result.append(boxlist_for_class)
    result = cat_boxlist(result)
    number_of_detections = len(result)
    if number_of_detections > self.detections_per_img > 0:
      # Keep only the top-k scoring detections across all classes: use the
      # k-th highest score as a threshold.
      cls_scores = result.get_field("scores")
      image_thresh, _ = topk(cls_scores, k=self.detections_per_img)
      image_thresh = image_thresh.numpy()[-1]
      keep = (cls_scores.numpy() >= image_thresh).nonzero()[0]
      result = result[keep]
    return result
  944. class RoIBoxHead:
  945. def __init__(self, in_channels):
  946. self.feature_extractor = FPN2MLPFeatureExtractor(in_channels)
  947. self.predictor = FPNPredictor()
  948. self.post_processor = PostProcessor(
  949. score_thresh=0.05,
  950. nms=0.5,
  951. detections_per_img=100,
  952. box_coder=BoxCoder(weights=(10., 10., 5., 5.)),
  953. cls_agnostic_bbox_reg=False
  954. )
  955. def __call__(self, features, proposals, targets=None):
  956. x = self.feature_extractor(features, proposals)
  957. class_logits, box_regression = self.predictor(x)
  958. if not Tensor.training:
  959. result = self.post_processor((class_logits, box_regression), proposals)
  960. return x, result, {}
  961. class MaskPostProcessor:
  962. # Not used in loss calculation
  963. def __call__(self, x, boxes):
  964. mask_prob = x.sigmoid().numpy()
  965. num_masks = x.shape[0]
  966. labels = [bbox.get_field("labels") for bbox in boxes]
  967. labels = Tensor.cat(*labels).numpy().astype(np.int32)
  968. index = np.arange(num_masks)
  969. mask_prob = mask_prob[index, labels][:, None]
  970. boxes_per_image, cumsum = [], 0
  971. for box in boxes:
  972. cumsum += len(box)
  973. boxes_per_image.append(cumsum)
  974. # using numpy here as Tensor.chunk doesnt have custom chunk sizes
  975. mask_prob = np.split(mask_prob, boxes_per_image, axis=0)
  976. results = []
  977. for prob, box in zip(mask_prob, boxes):
  978. bbox = BoxList(box.bbox, box.size, mode="xyxy")
  979. for field in box.fields():
  980. bbox.add_field(field, box.get_field(field))
  981. prob = Tensor(prob)
  982. bbox.add_field("mask", prob)
  983. results.append(bbox)
  984. return results
  985. class Mask:
  986. def __init__(self):
  987. self.feature_extractor = MaskRCNNFPNFeatureExtractor()
  988. self.predictor = MaskRCNNC4Predictor()
  989. self.post_processor = MaskPostProcessor()
  990. def __call__(self, features, proposals, targets=None):
  991. x = self.feature_extractor(features, proposals)
  992. if x:
  993. mask_logits = self.predictor(x)
  994. if not Tensor.training:
  995. result = self.post_processor(mask_logits, proposals)
  996. return x, result, {}
  997. return x, [], {}
  998. class RoIHeads:
  999. def __init__(self, in_channels):
  1000. self.box = RoIBoxHead(in_channels)
  1001. self.mask = Mask()
  1002. def __call__(self, features, proposals, targets=None):
  1003. x, detections, _ = self.box(features, proposals, targets)
  1004. x, detections, _ = self.mask(features, detections, targets)
  1005. return x, detections, {}
  1006. class ImageList(object):
  1007. def __init__(self, tensors, image_sizes):
  1008. self.tensors = tensors
  1009. self.image_sizes = image_sizes
  1010. def to(self, *args, **kwargs):
  1011. cast_tensor = self.tensors.to(*args, **kwargs)
  1012. return ImageList(cast_tensor, self.image_sizes)
  1013. def to_image_list(tensors, size_divisible=32):
  1014. # Preprocessing
  1015. if isinstance(tensors, Tensor) and size_divisible > 0:
  1016. tensors = [tensors]
  1017. if isinstance(tensors, ImageList):
  1018. return tensors
  1019. elif isinstance(tensors, Tensor):
  1020. # single tensor shape can be inferred
  1021. assert tensors.ndim == 4
  1022. image_sizes = [tensor.shape[-2:] for tensor in tensors]
  1023. return ImageList(tensors, image_sizes)
  1024. elif isinstance(tensors, (tuple, list)):
  1025. max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))
  1026. if size_divisible > 0:
  1027. stride = size_divisible
  1028. max_size = list(max_size)
  1029. max_size[1] = int(math.ceil(max_size[1] / stride) * stride)
  1030. max_size[2] = int(math.ceil(max_size[2] / stride) * stride)
  1031. max_size = tuple(max_size)
  1032. batch_shape = (len(tensors),) + max_size
  1033. batched_imgs = np.zeros(batch_shape, dtype=_to_np_dtype(tensors[0].dtype))
  1034. for img, pad_img in zip(tensors, batched_imgs):
  1035. pad_img[: img.shape[0], : img.shape[1], : img.shape[2]] += img.numpy()
  1036. batched_imgs = Tensor(batched_imgs)
  1037. image_sizes = [im.shape[-2:] for im in tensors]
  1038. return ImageList(batched_imgs, image_sizes)
  1039. else:
  1040. raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors)))
class MaskRCNN:
  """End-to-end Mask R-CNN: ResNet-FPN backbone -> RPN -> box/mask ROI heads."""
  def __init__(self, backbone: ResNet):
    self.backbone = ResNetFPN(backbone, out_channels=256)
    self.rpn = RPN(self.backbone.out_channels)
    self.roi_heads = RoIHeads(self.backbone.out_channels)
  def load_from_pretrained(self):
    """Download the reference checkpoint and copy its weights into this model.

    Returns the list of checkpoint keys (after renaming) that were loaded.
    """
    fn = Path('./') / "weights/maskrcnn.pt"
    fetch("https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_FPN_1x.pth", fn)
    state_dict = torch_load(fn)['model']
    loaded_keys = []
    for k, v in state_dict.items():
      # Strip wrapper prefixes so checkpoint keys match this implementation's
      # attribute paths (order matters: prefixes first, then FPN renames).
      if "module." in k:
        k = k.replace("module.", "")
      if "stem." in k:
        k = k.replace("stem.", "")
      # The checkpoint names FPN convs fpn_inner<N>/fpn_layer<N> (1-based);
      # here they live in 0-based inner_blocks/layer_blocks lists.
      if "fpn_inner" in k:
        block_index = int(re.search(r"fpn_inner(\d+)", k).group(1))
        k = re.sub(r"fpn_inner\d+", f"inner_blocks.{block_index - 1}", k)
      if "fpn_layer" in k:
        block_index = int(re.search(r"fpn_layer(\d+)", k).group(1))
        k = re.sub(r"fpn_layer\d+", f"layer_blocks.{block_index - 1}", k)
      loaded_keys.append(k)
      get_child(self, k).assign(v.numpy()).realize()
    return loaded_keys
  def __call__(self, images):
    """Run inference; returns the ROI heads' per-image detection results."""
    images = to_image_list(images)
    features = self.backbone(images.tensors)
    proposals, _ = self.rpn(images, features)
    x, result, _ = self.roi_heads(features, proposals)
    return result
  1071. if __name__ == '__main__':
  1072. resnet = resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
  1073. model = MaskRCNN(backbone=resnet)
  1074. model.load_from_pretrained()