mask_rcnn.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. from extra.models.mask_rcnn import MaskRCNN
  2. from extra.models.resnet import ResNet
  3. from extra.models.mask_rcnn import BoxList
  4. from torch.nn import functional as F
  5. from torchvision import transforms as T
  6. from torchvision.transforms import functional as Ft
  7. import random
  8. from tinygrad.tensor import Tensor
  9. from PIL import Image
  10. import numpy as np
  11. import torch
  12. import argparse
  13. import cv2
  14. class Resize:
  15. def __init__(self, min_size, max_size):
  16. if not isinstance(min_size, (list, tuple)):
  17. min_size = (min_size,)
  18. self.min_size = min_size
  19. self.max_size = max_size
  20. # modified from torchvision to add support for max size
  21. def get_size(self, image_size):
  22. w, h = image_size
  23. size = random.choice(self.min_size)
  24. max_size = self.max_size
  25. if max_size is not None:
  26. min_original_size = float(min((w, h)))
  27. max_original_size = float(max((w, h)))
  28. if max_original_size / min_original_size * size > max_size:
  29. size = int(round(max_size * min_original_size / max_original_size))
  30. if (w <= h and w == size) or (h <= w and h == size):
  31. return (h, w)
  32. if w < h:
  33. ow = size
  34. oh = int(size * h / w)
  35. else:
  36. oh = size
  37. ow = int(size * w / h)
  38. return (oh, ow)
  39. def __call__(self, image):
  40. size = self.get_size(image.size)
  41. image = Ft.resize(image, size)
  42. return image
  43. class Normalize:
  44. def __init__(self, mean, std, to_bgr255=True):
  45. self.mean = mean
  46. self.std = std
  47. self.to_bgr255 = to_bgr255
  48. def __call__(self, image):
  49. if self.to_bgr255:
  50. image = image[[2, 1, 0]] * 255
  51. else:
  52. image = image[[0, 1, 2]] * 255
  53. image = Ft.normalize(image, mean=self.mean, std=self.std)
  54. return image
  55. transforms = lambda size_scale: T.Compose(
  56. [
  57. Resize(int(800*size_scale), int(1333*size_scale)),
  58. T.ToTensor(),
  59. Normalize(
  60. mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.], to_bgr255=True
  61. ),
  62. ]
  63. )
  64. def expand_boxes(boxes, scale):
  65. w_half = (boxes[:, 2] - boxes[:, 0]) * .5
  66. h_half = (boxes[:, 3] - boxes[:, 1]) * .5
  67. x_c = (boxes[:, 2] + boxes[:, 0]) * .5
  68. y_c = (boxes[:, 3] + boxes[:, 1]) * .5
  69. w_half *= scale
  70. h_half *= scale
  71. boxes_exp = torch.zeros_like(boxes)
  72. boxes_exp[:, 0] = x_c - w_half
  73. boxes_exp[:, 2] = x_c + w_half
  74. boxes_exp[:, 1] = y_c - h_half
  75. boxes_exp[:, 3] = y_c + h_half
  76. return boxes_exp
  77. def expand_masks(mask, padding):
  78. N = mask.shape[0]
  79. M = mask.shape[-1]
  80. pad2 = 2 * padding
  81. scale = float(M + pad2) / M
  82. padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2))
  83. padded_mask[:, :, padding:-padding, padding:-padding] = mask
  84. return padded_mask, scale
  85. def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
  86. # TODO: remove torch
  87. mask = torch.tensor(mask.numpy())
  88. box = torch.tensor(box.numpy())
  89. padded_mask, scale = expand_masks(mask[None], padding=padding)
  90. mask = padded_mask[0, 0]
  91. box = expand_boxes(box[None], scale)[0]
  92. box = box.to(dtype=torch.int32)
  93. TO_REMOVE = 1
  94. w = int(box[2] - box[0] + TO_REMOVE)
  95. h = int(box[3] - box[1] + TO_REMOVE)
  96. w = max(w, 1)
  97. h = max(h, 1)
  98. mask = mask.expand((1, 1, -1, -1))
  99. mask = mask.to(torch.float32)
  100. mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
  101. mask = mask[0][0]
  102. if thresh >= 0:
  103. mask = mask > thresh
  104. else:
  105. mask = (mask * 255).to(torch.uint8)
  106. im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8)
  107. x_0 = max(box[0], 0)
  108. x_1 = min(box[2] + 1, im_w)
  109. y_0 = max(box[1], 0)
  110. y_1 = min(box[3] + 1, im_h)
  111. im_mask[y_0:y_1, x_0:x_1] = mask[
  112. (y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])
  113. ]
  114. return im_mask
  115. class Masker:
  116. def __init__(self, threshold=0.5, padding=1):
  117. self.threshold = threshold
  118. self.padding = padding
  119. def forward_single_image(self, masks, boxes):
  120. boxes = boxes.convert("xyxy")
  121. im_w, im_h = boxes.size
  122. res = [
  123. paste_mask_in_image(mask[0], box, im_h, im_w, self.threshold, self.padding)
  124. for mask, box in zip(masks, boxes.bbox)
  125. ]
  126. if len(res) > 0:
  127. res = torch.stack(*res, dim=0)[:, None]
  128. else:
  129. res = masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1]))
  130. return Tensor(res.numpy())
  131. def __call__(self, masks, boxes):
  132. if isinstance(boxes, BoxList):
  133. boxes = [boxes]
  134. results = []
  135. for mask, box in zip(masks, boxes):
  136. result = self.forward_single_image(mask, box)
  137. results.append(result)
  138. return results
# Module-level masker shared by compute_prediction below.
masker = Masker(threshold=0.5, padding=1)
  140. def select_top_predictions(predictions, confidence_threshold=0.9):
  141. scores = predictions.get_field("scores").numpy()
  142. keep = [idx for idx, score in enumerate(scores) if score > confidence_threshold]
  143. return predictions[keep]
  144. def compute_prediction(original_image, model, confidence_threshold, size_scale=1.0):
  145. image = transforms(size_scale)(original_image).numpy()
  146. image = Tensor(image, requires_grad=False)
  147. predictions = model(image)
  148. prediction = predictions[0]
  149. prediction = select_top_predictions(prediction, confidence_threshold)
  150. width, height = original_image.size
  151. prediction = prediction.resize((width, height))
  152. if prediction.has_field("mask"):
  153. masks = prediction.get_field("mask")
  154. masks = masker([masks], [prediction])[0]
  155. prediction.add_field("mask", masks)
  156. return prediction
  157. def compute_prediction_batched(batch, model, size_scale=1.0):
  158. imgs = []
  159. for img in batch:
  160. imgs.append(transforms(size_scale)(img).numpy())
  161. image = [Tensor(image, requires_grad=False) for image in imgs]
  162. predictions = model(image)
  163. del image
  164. return predictions
# Fixed multipliers used by compute_colors_for_labels to derive a stable
# per-class color (label * palette % 255).
palette = np.array([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
  166. def findContours(*args, **kwargs):
  167. if cv2.__version__.startswith('4'):
  168. contours, hierarchy = cv2.findContours(*args, **kwargs)
  169. elif cv2.__version__.startswith('3'):
  170. _, contours, hierarchy = cv2.findContours(*args, **kwargs)
  171. return contours, hierarchy
  172. def compute_colors_for_labels(labels):
  173. l = labels[:, None]
  174. colors = l * palette
  175. colors = (colors % 255).astype("uint8")
  176. return colors
  177. def overlay_mask(image, predictions):
  178. image = np.asarray(image)
  179. masks = predictions.get_field("mask").numpy()
  180. labels = predictions.get_field("labels").numpy()
  181. colors = compute_colors_for_labels(labels).tolist()
  182. for mask, color in zip(masks, colors):
  183. thresh = mask[0, :, :, None]
  184. contours, hierarchy = findContours(
  185. thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
  186. )
  187. image = cv2.drawContours(image, contours, -1, color, 3)
  188. composite = image
  189. return composite
# COCO class names indexed by the integer labels the model emits
# (index 0 is the background slot).
CATEGORIES = [
  "__background", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
  "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
  "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
  "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
  "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli",
  "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table",
  "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster",
  "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]
  200. def overlay_boxes(image, predictions):
  201. labels = predictions.get_field("labels").numpy()
  202. boxes = predictions.bbox
  203. image = np.asarray(image)
  204. colors = compute_colors_for_labels(labels).tolist()
  205. for box, color in zip(boxes, colors):
  206. box = torch.tensor(box.numpy())
  207. box = box.to(torch.int64)
  208. top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
  209. image = cv2.rectangle(
  210. image, tuple(top_left), tuple(bottom_right), tuple(color), 1
  211. )
  212. return image
  213. def overlay_class_names(image, predictions):
  214. scores = predictions.get_field("scores").numpy().tolist()
  215. labels = predictions.get_field("labels").numpy().tolist()
  216. labels = [CATEGORIES[int(i)] for i in labels]
  217. boxes = predictions.bbox.numpy()
  218. image = np.asarray(image)
  219. template = "{}: {:.2f}"
  220. for box, score, label in zip(boxes, scores, labels):
  221. x, y = box[:2]
  222. s = template.format(label, score)
  223. x, y = int(x), int(y)
  224. cv2.putText(
  225. image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1
  226. )
  227. return image
  228. if __name__ == '__main__':
  229. parser = argparse.ArgumentParser(description='Run MaskRCNN', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  230. parser.add_argument('--image', type=str, help="Path of the image to run")
  231. parser.add_argument('--threshold', type=float, default=0.7, help="Detector threshold")
  232. parser.add_argument('--size_scale', type=float, default=1.0, help="Image resize multiplier")
  233. parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
  234. args = parser.parse_args()
  235. resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
  236. model_tiny = MaskRCNN(resnet)
  237. model_tiny.load_from_pretrained()
  238. img = Image.open(args.image)
  239. top_result_tiny = compute_prediction(img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale)
  240. bbox_image = overlay_boxes(img, top_result_tiny)
  241. mask_image = overlay_mask(bbox_image, top_result_tiny)
  242. final_image = overlay_class_names(mask_image, top_result_tiny)
  243. im = Image.fromarray(final_image)
  244. print(f"saving {args.out}")
  245. im.save(args.out)
  246. im.show()