coco.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. import json
  2. import pathlib
  3. import zipfile
  4. import numpy as np
  5. from tinygrad.helpers import fetch
  6. import pycocotools._mask as _mask
  7. from examples.mask_rcnn import Masker
  8. from pycocotools.coco import COCO
  9. from pycocotools.cocoeval import COCOeval
  10. iou = _mask.iou
  11. merge = _mask.merge
  12. frPyObjects = _mask.frPyObjects
  13. BASEDIR = pathlib.Path(__file__).parent / "COCO"
  14. BASEDIR.mkdir(exist_ok=True)
  15. def create_dict(key_row, val_row, rows): return {row[key_row]:row[val_row] for row in rows}
  16. if not pathlib.Path(BASEDIR/'val2017').is_dir():
  17. fn = fetch('http://images.cocodataset.org/zips/val2017.zip')
  18. with zipfile.ZipFile(fn, 'r') as zip_ref:
  19. zip_ref.extractall(BASEDIR)
  20. fn.unlink()
  21. if not pathlib.Path(BASEDIR/'annotations').is_dir():
  22. fn = fetch('http://images.cocodataset.org/annotations/annotations_trainval2017.zip')
  23. with zipfile.ZipFile(fn, 'r') as zip_ref:
  24. zip_ref.extractall(BASEDIR)
  25. fn.unlink()
  26. with open(BASEDIR/'annotations/instances_val2017.json', 'r') as f:
  27. annotations_raw = json.loads(f.read())
  28. images = annotations_raw['images']
  29. categories = annotations_raw['categories']
  30. annotations = annotations_raw['annotations']
  31. file_name_to_id = create_dict('file_name', 'id', images)
  32. id_to_width = create_dict('id', 'width', images)
  33. id_to_height = create_dict('id', 'height', images)
  34. json_category_id_to_contiguous_id = {v['id']: i + 1 for i, v in enumerate(categories)}
  35. contiguous_category_id_to_json_id = {v:k for k,v in json_category_id_to_contiguous_id.items()}
  36. def encode(bimask):
  37. if len(bimask.shape) == 3:
  38. return _mask.encode(bimask)
  39. elif len(bimask.shape) == 2:
  40. h, w = bimask.shape
  41. return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0]
  42. def decode(rleObjs):
  43. if type(rleObjs) == list:
  44. return _mask.decode(rleObjs)
  45. else:
  46. return _mask.decode([rleObjs])[:,:,0]
  47. def area(rleObjs):
  48. if type(rleObjs) == list:
  49. return _mask.area(rleObjs)
  50. else:
  51. return _mask.area([rleObjs])[0]
  52. def toBbox(rleObjs):
  53. if type(rleObjs) == list:
  54. return _mask.toBbox(rleObjs)
  55. else:
  56. return _mask.toBbox([rleObjs])[0]
  57. def convert_prediction_to_coco_bbox(file_name, prediction):
  58. coco_results = []
  59. try:
  60. original_id = file_name_to_id[file_name]
  61. if len(prediction) == 0:
  62. return coco_results
  63. image_width = id_to_width[original_id]
  64. image_height = id_to_height[original_id]
  65. prediction = prediction.resize((image_width, image_height))
  66. prediction = prediction.convert("xywh")
  67. boxes = prediction.bbox.numpy().tolist()
  68. scores = prediction.get_field("scores").numpy().tolist()
  69. labels = prediction.get_field("labels").numpy().tolist()
  70. mapped_labels = [contiguous_category_id_to_json_id[int(i)] for i in labels]
  71. coco_results.extend(
  72. [
  73. {
  74. "image_id": original_id,
  75. "category_id": mapped_labels[k],
  76. "bbox": box,
  77. "score": scores[k],
  78. }
  79. for k, box in enumerate(boxes)
  80. ]
  81. )
  82. except Exception as e:
  83. print(file_name, e)
  84. return coco_results
  85. masker = Masker(threshold=0.5, padding=1)
  86. def convert_prediction_to_coco_mask(file_name, prediction):
  87. coco_results = []
  88. try:
  89. original_id = file_name_to_id[file_name]
  90. if len(prediction) == 0:
  91. return coco_results
  92. image_width = id_to_width[original_id]
  93. image_height = id_to_height[original_id]
  94. prediction = prediction.resize((image_width, image_height))
  95. masks = prediction.get_field("mask")
  96. scores = prediction.get_field("scores").numpy().tolist()
  97. labels = prediction.get_field("labels").numpy().tolist()
  98. masks = masker([masks], [prediction])[0].numpy()
  99. rles = [
  100. encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
  101. for mask in masks
  102. ]
  103. for rle in rles:
  104. rle["counts"] = rle["counts"].decode("utf-8")
  105. mapped_labels = [contiguous_category_id_to_json_id[int(i)] for i in labels]
  106. coco_results.extend(
  107. [
  108. {
  109. "image_id": original_id,
  110. "category_id": mapped_labels[k],
  111. "segmentation": rle,
  112. "score": scores[k],
  113. }
  114. for k, rle in enumerate(rles)
  115. ]
  116. )
  117. except Exception as e:
  118. print(file_name, e)
  119. return coco_results
  120. def accumulate_predictions_for_coco(coco_results, json_result_file, rm=False):
  121. path = pathlib.Path(json_result_file)
  122. if rm and path.exists(): path.unlink()
  123. with open(path, "a") as f:
  124. for s in coco_results:
  125. f.write(json.dumps(s))
  126. f.write('\n')
  127. def remove_dup(l):
  128. seen = set()
  129. seen_add = seen.add
  130. return [x for x in l if not (x in seen or seen_add(x))]
  131. class NpEncoder(json.JSONEncoder):
  132. def default(self, obj):
  133. if isinstance(obj, np.integer):
  134. return int(obj)
  135. if isinstance(obj, np.floating):
  136. return float(obj)
  137. if isinstance(obj, np.ndarray):
  138. return obj.tolist()
  139. return super(NpEncoder, self).default(obj)
  140. def evaluate_predictions_on_coco(json_result_file, iou_type="bbox"):
  141. coco_results = []
  142. with open(json_result_file, "r") as f:
  143. for line in f:
  144. coco_results.append(json.loads(line))
  145. coco_gt = COCO(str(BASEDIR/'annotations/instances_val2017.json'))
  146. set_of_json = remove_dup([json.dumps(d, cls=NpEncoder) for d in coco_results])
  147. unique_list = [json.loads(s) for s in set_of_json]
  148. with open(f'{json_result_file}.flattend', "w") as f:
  149. json.dump(unique_list, f)
  150. coco_dt = coco_gt.loadRes(str(f'{json_result_file}.flattend'))
  151. coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
  152. coco_eval.evaluate()
  153. coco_eval.accumulate()
  154. coco_eval.summarize()
  155. return coco_eval
  156. def iterate(files, bs=1):
  157. batch = []
  158. for file in files:
  159. batch.append(file)
  160. if len(batch) >= bs: yield batch; batch = []
  161. if len(batch) > 0: yield batch; batch = []