| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271 |
- import re
- import math
- import os
- import numpy as np
- from pathlib import Path
- from tinygrad import nn, Tensor, dtypes
- from tinygrad.tensor import _to_np_dtype
- from tinygrad.helpers import get_child, fetch
- from tinygrad.nn.state import torch_load
- from extra.models.resnet import ResNet
- from extra.models.retinanet import nms as _box_nms
- USE_NP_GATHER = os.getenv('FULL_TINYGRAD', '0') == '0'
- def rint(tensor):
- x = (tensor*2).cast(dtypes.int32).contiguous().cast(dtypes.float32)/2
- return (x<0).where(x.floor(), x.ceil())
- def nearest_interpolate(tensor, scale_factor):
- bs, c, py, px = tensor.shape
- return tensor.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, scale_factor, px, scale_factor).reshape(bs, c, py * scale_factor, px * scale_factor)
- def meshgrid(x, y):
- grid_x = Tensor.cat(*[x[idx:idx+1].expand(y.shape).unsqueeze(0) for idx in range(x.shape[0])])
- grid_y = Tensor.cat(*[y.unsqueeze(0)]*x.shape[0])
- return grid_x.reshape(-1, 1), grid_y.reshape(-1, 1)
- def topk(input_, k, dim=-1, largest=True, sorted=False):
- k = min(k, input_.shape[dim]-1)
- input_ = input_.numpy()
- if largest: input_ *= -1
- ind = np.argpartition(input_, k, axis=dim)
- if largest: input_ *= -1
- ind = np.take(ind, np.arange(k), axis=dim) # k non-sorted indices
- input_ = np.take_along_axis(input_, ind, axis=dim) # k non-sorted values
- if not sorted: return Tensor(input_), ind
- if largest: input_ *= -1
- ind_part = np.argsort(input_, axis=dim)
- ind = np.take_along_axis(ind, ind_part, axis=dim)
- if largest: input_ *= -1
- val = np.take_along_axis(input_, ind_part, axis=dim)
- return Tensor(val), ind
- # This is very slow for large arrays, or indices
- def _gather(array, indices):
- indices = indices.float().to(array.device)
- reshape_arg = [1]*array.ndim + [array.shape[-1]]
- return Tensor.where(
- indices.unsqueeze(indices.ndim).expand(*indices.shape, array.shape[-1]) == Tensor.arange(array.shape[-1]).reshape(*reshape_arg).expand(*indices.shape, array.shape[-1]),
- array, 0,
- ).sum(indices.ndim)
- # TODO: replace npgather with a faster gather using tinygrad only
- # NOTE: this blocks the gradient
- def npgather(array,indices):
- if isinstance(array, Tensor): array = array.numpy()
- if isinstance(indices, Tensor): indices = indices.numpy()
- if isinstance(indices, list): indices = np.asarray(indices)
- return Tensor(array[indices.astype(int)])
- def get_strides(shape):
- prod = [1]
- for idx in range(len(shape)-1, -1, -1): prod.append(prod[-1] * shape[idx])
- # something about ints is broken with gpu, cuda
- return Tensor(prod[::-1][1:], dtype=dtypes.int32).unsqueeze(0)
- # with keys as integer array for all axes
- def tensor_getitem(tensor, *keys):
- # something about ints is broken with gpu, cuda
- flat_keys = Tensor.stack(*[key.expand((sum(keys)).shape).reshape(-1) for key in keys], dim=1).cast(dtypes.int32)
- strides = get_strides(tensor.shape)
- idxs = (flat_keys * strides).sum(1)
- gatherer = npgather if USE_NP_GATHER else _gather
- return gatherer(tensor.reshape(-1), idxs).reshape(sum(keys).shape)
- # for gather with indicies only on axis=0
- def tensor_gather(tensor, indices):
- if not isinstance(indices, Tensor):
- indices = Tensor(indices, requires_grad=False)
- if len(tensor.shape) > 2:
- rem_shape = list(tensor.shape)[1:]
- tensor = tensor.reshape(tensor.shape[0], -1)
- else:
- rem_shape = None
- if len(tensor.shape) > 1:
- tensor = tensor.T
- repeat_arg = [1]*(tensor.ndim-1) + [tensor.shape[-2]]
- indices = indices.unsqueeze(indices.ndim).repeat(repeat_arg)
- ret = _gather(tensor, indices)
- if rem_shape:
- ret = ret.reshape([indices.shape[0]] + rem_shape)
- else:
- ret = _gather(tensor, indices)
- del indices
- return ret
- class LastLevelMaxPool:
- def __call__(self, x): return [Tensor.max_pool2d(x, 1, 2)]
- # transpose
- FLIP_LEFT_RIGHT = 0
- FLIP_TOP_BOTTOM = 1
- def permute_and_flatten(layer:Tensor, N, A, C, H, W):
- layer = layer.reshape(N, -1, C, H, W)
- layer = layer.permute(0, 3, 4, 1, 2)
- layer = layer.reshape(N, -1, C)
- return layer
- class BoxList:
- def __init__(self, bbox, image_size, mode="xyxy"):
- if not isinstance(bbox, Tensor):
- bbox = Tensor(bbox)
- if bbox.ndim != 2:
- raise ValueError(
- "bbox should have 2 dimensions, got {}".format(bbox.ndim)
- )
- if bbox.shape[-1] != 4:
- raise ValueError(
- "last dimenion of bbox should have a "
- "size of 4, got {}".format(bbox.shape[-1])
- )
- if mode not in ("xyxy", "xywh"):
- raise ValueError("mode should be 'xyxy' or 'xywh'")
- self.bbox = bbox
- self.size = image_size # (image_width, image_height)
- self.mode = mode
- self.extra_fields = {}
- def __repr__(self):
- s = self.__class__.__name__ + "("
- s += "num_boxes={}, ".format(len(self))
- s += "image_width={}, ".format(self.size[0])
- s += "image_height={}, ".format(self.size[1])
- s += "mode={})".format(self.mode)
- return s
- def area(self):
- box = self.bbox
- if self.mode == "xyxy":
- TO_REMOVE = 1
- area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE)
- elif self.mode == "xywh":
- area = box[:, 2] * box[:, 3]
- return area
- def add_field(self, field, field_data):
- self.extra_fields[field] = field_data
- def get_field(self, field):
- return self.extra_fields[field]
- def has_field(self, field):
- return field in self.extra_fields
- def fields(self):
- return list(self.extra_fields.keys())
- def _copy_extra_fields(self, bbox):
- for k, v in bbox.extra_fields.items():
- self.extra_fields[k] = v
- def convert(self, mode):
- if mode == self.mode:
- return self
- xmin, ymin, xmax, ymax = self._split_into_xyxy()
- if mode == "xyxy":
- bbox = Tensor.cat(*(xmin, ymin, xmax, ymax), dim=-1)
- bbox = BoxList(bbox, self.size, mode=mode)
- else:
- TO_REMOVE = 1
- bbox = Tensor.cat(
- *(xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=-1
- )
- bbox = BoxList(bbox, self.size, mode=mode)
- bbox._copy_extra_fields(self)
- return bbox
- def _split_into_xyxy(self):
- if self.mode == "xyxy":
- xmin, ymin, xmax, ymax = self.bbox.chunk(4, dim=-1)
- return xmin, ymin, xmax, ymax
- if self.mode == "xywh":
- TO_REMOVE = 1
- xmin, ymin, w, h = self.bbox.chunk(4, dim=-1)
- return (
- xmin,
- ymin,
- xmin + (w - TO_REMOVE).clamp(min=0),
- ymin + (h - TO_REMOVE).clamp(min=0),
- )
- def resize(self, size, *args, **kwargs):
- ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
- if ratios[0] == ratios[1]:
- ratio = ratios[0]
- scaled_box = self.bbox * ratio
- bbox = BoxList(scaled_box, size, mode=self.mode)
- for k, v in self.extra_fields.items():
- if not isinstance(v, Tensor):
- v = v.resize(size, *args, **kwargs)
- bbox.add_field(k, v)
- return bbox
- ratio_width, ratio_height = ratios
- xmin, ymin, xmax, ymax = self._split_into_xyxy()
- scaled_xmin = xmin * ratio_width
- scaled_xmax = xmax * ratio_width
- scaled_ymin = ymin * ratio_height
- scaled_ymax = ymax * ratio_height
- scaled_box = Tensor.cat(
- *(scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1
- )
- bbox = BoxList(scaled_box, size, mode="xyxy")
- for k, v in self.extra_fields.items():
- if not isinstance(v, Tensor):
- v = v.resize(size, *args, **kwargs)
- bbox.add_field(k, v)
- return bbox.convert(self.mode)
- def transpose(self, method):
- image_width, image_height = self.size
- xmin, ymin, xmax, ymax = self._split_into_xyxy()
- if method == FLIP_LEFT_RIGHT:
- TO_REMOVE = 1
- transposed_xmin = image_width - xmax - TO_REMOVE
- transposed_xmax = image_width - xmin - TO_REMOVE
- transposed_ymin = ymin
- transposed_ymax = ymax
- elif method == FLIP_TOP_BOTTOM:
- transposed_xmin = xmin
- transposed_xmax = xmax
- transposed_ymin = image_height - ymax
- transposed_ymax = image_height - ymin
- transposed_boxes = Tensor.cat(
- *(transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1
- )
- bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
- for k, v in self.extra_fields.items():
- if not isinstance(v, Tensor):
- v = v.transpose(method)
- bbox.add_field(k, v)
- return bbox.convert(self.mode)
- def clip_to_image(self, remove_empty=True):
- TO_REMOVE = 1
- bb1 = self.bbox.clip(min_=0, max_=self.size[0] - TO_REMOVE)[:, 0]
- bb2 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 1]
- bb3 = self.bbox.clip(min_=0, max_=self.size[0] - TO_REMOVE)[:, 2]
- bb4 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 3]
- self.bbox = Tensor.stack(bb1, bb2, bb3, bb4, dim=1)
- if remove_empty:
- box = self.bbox
- keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
- return self[keep]
- return self
- def __getitem__(self, item):
- if isinstance(item, list):
- if len(item) == 0:
- return []
- if sum(item) == len(item) and isinstance(item[0], bool):
- return self
- bbox = BoxList(tensor_gather(self.bbox, item), self.size, self.mode)
- for k, v in self.extra_fields.items():
- bbox.add_field(k, tensor_gather(v, item))
- return bbox
- def __len__(self):
- return self.bbox.shape[0]
- def cat_boxlist(bboxes):
- size = bboxes[0].size
- mode = bboxes[0].mode
- fields = set(bboxes[0].fields())
- cat_box_list = [bbox.bbox for bbox in bboxes if bbox.bbox.shape[0] > 0]
- if len(cat_box_list) > 0:
- cat_boxes = BoxList(Tensor.cat(*cat_box_list, dim=0), size, mode)
- else:
- cat_boxes = BoxList(bboxes[0].bbox, size, mode)
- for field in fields:
- cat_field_list = [bbox.get_field(field) for bbox in bboxes if bbox.get_field(field).shape[0] > 0]
- if len(cat_box_list) > 0:
- data = Tensor.cat(*cat_field_list, dim=0)
- else:
- data = bboxes[0].get_field(field)
- cat_boxes.add_field(field, data)
- return cat_boxes
- class FPN:
- def __init__(self, in_channels_list, out_channels):
- self.inner_blocks, self.layer_blocks = [], []
- for in_channels in in_channels_list:
- self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
- self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
- self.top_block = LastLevelMaxPool()
- def __call__(self, x: Tensor):
- last_inner = self.inner_blocks[-1](x[-1])
- results = []
- results.append(self.layer_blocks[-1](last_inner))
- for feature, inner_block, layer_block in zip(
- x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
- ):
- if not inner_block:
- continue
- inner_top_down = nearest_interpolate(last_inner, scale_factor=2)
- inner_lateral = inner_block(feature)
- last_inner = inner_lateral + inner_top_down
- layer_result = layer_block(last_inner)
- results.insert(0, layer_result)
- last_results = self.top_block(results[-1])
- results.extend(last_results)
- return tuple(results)
- class ResNetFPN:
- def __init__(self, resnet, out_channels=256):
- self.out_channels = out_channels
- self.body = resnet
- in_channels_stage2 = 256
- in_channels_list = [
- in_channels_stage2,
- in_channels_stage2 * 2,
- in_channels_stage2 * 4,
- in_channels_stage2 * 8,
- ]
- self.fpn = FPN(in_channels_list, out_channels)
- def __call__(self, x):
- x = self.body(x)
- return self.fpn(x)
- class AnchorGenerator:
- def __init__(
- self,
- sizes=(32, 64, 128, 256, 512),
- aspect_ratios=(0.5, 1.0, 2.0),
- anchor_strides=(4, 8, 16, 32, 64),
- straddle_thresh=0,
- ):
- if len(anchor_strides) == 1:
- anchor_stride = anchor_strides[0]
- cell_anchors = [
- generate_anchors(anchor_stride, sizes, aspect_ratios)
- ]
- else:
- if len(anchor_strides) != len(sizes):
- raise RuntimeError("FPN should have #anchor_strides == #sizes")
- cell_anchors = [
- generate_anchors(
- anchor_stride,
- size if isinstance(size, (tuple, list)) else (size,),
- aspect_ratios
- )
- for anchor_stride, size in zip(anchor_strides, sizes)
- ]
- self.strides = anchor_strides
- self.cell_anchors = cell_anchors
- self.straddle_thresh = straddle_thresh
- def num_anchors_per_location(self):
- return [cell_anchors.shape[0] for cell_anchors in self.cell_anchors]
- def grid_anchors(self, grid_sizes):
- anchors = []
- for size, stride, base_anchors in zip(
- grid_sizes, self.strides, self.cell_anchors
- ):
- grid_height, grid_width = size
- device = base_anchors.device
- shifts_x = Tensor.arange(
- start=0, stop=grid_width * stride, step=stride, dtype=dtypes.float32, device=device
- )
- shifts_y = Tensor.arange(
- start=0, stop=grid_height * stride, step=stride, dtype=dtypes.float32, device=device
- )
- shift_y, shift_x = meshgrid(shifts_y, shifts_x)
- shift_x = shift_x.reshape(-1)
- shift_y = shift_y.reshape(-1)
- shifts = Tensor.stack(shift_x, shift_y, shift_x, shift_y, dim=1)
- anchors.append(
- (shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(-1, 4)
- )
- return anchors
- def add_visibility_to(self, boxlist):
- image_width, image_height = boxlist.size
- anchors = boxlist.bbox
- if self.straddle_thresh >= 0:
- inds_inside = (
- (anchors[:, 0] >= -self.straddle_thresh)
- * (anchors[:, 1] >= -self.straddle_thresh)
- * (anchors[:, 2] < image_width + self.straddle_thresh)
- * (anchors[:, 3] < image_height + self.straddle_thresh)
- )
- else:
- device = anchors.device
- inds_inside = Tensor.ones(anchors.shape[0], dtype=dtypes.uint8, device=device)
- boxlist.add_field("visibility", inds_inside)
- def __call__(self, image_list, feature_maps):
- grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
- anchors_over_all_feature_maps = self.grid_anchors(grid_sizes)
- anchors = []
- for (image_height, image_width) in image_list.image_sizes:
- anchors_in_image = []
- for anchors_per_feature_map in anchors_over_all_feature_maps:
- boxlist = BoxList(
- anchors_per_feature_map, (image_width, image_height), mode="xyxy"
- )
- self.add_visibility_to(boxlist)
- anchors_in_image.append(boxlist)
- anchors.append(anchors_in_image)
- return anchors
- def generate_anchors(
- stride=16, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)
- ):
- return _generate_anchors(stride, Tensor(list(sizes)) / stride, Tensor(list(aspect_ratios)))
- def _generate_anchors(base_size, scales, aspect_ratios):
- anchor = Tensor([1, 1, base_size, base_size]) - 1
- anchors = _ratio_enum(anchor, aspect_ratios)
- anchors = Tensor.cat(
- *[_scale_enum(anchors[i, :], scales).reshape(-1, 4) for i in range(anchors.shape[0])]
- )
- return anchors
- def _whctrs(anchor):
- w = anchor[2] - anchor[0] + 1
- h = anchor[3] - anchor[1] + 1
- x_ctr = anchor[0] + 0.5 * (w - 1)
- y_ctr = anchor[1] + 0.5 * (h - 1)
- return w, h, x_ctr, y_ctr
- def _mkanchors(ws, hs, x_ctr, y_ctr):
- ws = ws[:, None]
- hs = hs[:, None]
- anchors = Tensor.cat(*(
- x_ctr - 0.5 * (ws - 1),
- y_ctr - 0.5 * (hs - 1),
- x_ctr + 0.5 * (ws - 1),
- y_ctr + 0.5 * (hs - 1),
- ), dim=1)
- return anchors
- def _ratio_enum(anchor, ratios):
- w, h, x_ctr, y_ctr = _whctrs(anchor)
- size = w * h
- size_ratios = size / ratios
- ws = rint(Tensor.sqrt(size_ratios))
- hs = rint(ws * ratios)
- anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
- return anchors
- def _scale_enum(anchor, scales):
- w, h, x_ctr, y_ctr = _whctrs(anchor)
- ws = w * scales
- hs = h * scales
- anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
- return anchors
- class RPNHead:
- def __init__(self, in_channels, num_anchors):
- self.conv = nn.Conv2d(in_channels, 256, kernel_size=3, padding=1)
- self.cls_logits = nn.Conv2d(256, num_anchors, kernel_size=1)
- self.bbox_pred = nn.Conv2d(256, num_anchors * 4, kernel_size=1)
- def __call__(self, x):
- logits = []
- bbox_reg = []
- for feature in x:
- t = Tensor.relu(self.conv(feature))
- logits.append(self.cls_logits(t))
- bbox_reg.append(self.bbox_pred(t))
- return logits, bbox_reg
- class BoxCoder(object):
- def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
- self.weights = weights
- self.bbox_xform_clip = bbox_xform_clip
- def encode(self, reference_boxes, proposals):
- TO_REMOVE = 1 # TODO remove
- ex_widths = proposals[:, 2] - proposals[:, 0] + TO_REMOVE
- ex_heights = proposals[:, 3] - proposals[:, 1] + TO_REMOVE
- ex_ctr_x = proposals[:, 0] + 0.5 * ex_widths
- ex_ctr_y = proposals[:, 1] + 0.5 * ex_heights
- gt_widths = reference_boxes[:, 2] - reference_boxes[:, 0] + TO_REMOVE
- gt_heights = reference_boxes[:, 3] - reference_boxes[:, 1] + TO_REMOVE
- gt_ctr_x = reference_boxes[:, 0] + 0.5 * gt_widths
- gt_ctr_y = reference_boxes[:, 1] + 0.5 * gt_heights
- wx, wy, ww, wh = self.weights
- targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
- targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
- targets_dw = ww * Tensor.log(gt_widths / ex_widths)
- targets_dh = wh * Tensor.log(gt_heights / ex_heights)
- targets = Tensor.stack(targets_dx, targets_dy, targets_dw, targets_dh, dim=1)
- return targets
- def decode(self, rel_codes, boxes):
- boxes = boxes.cast(rel_codes.dtype)
- rel_codes = rel_codes
- TO_REMOVE = 1 # TODO remove
- widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE
- heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE
- ctr_x = boxes[:, 0] + 0.5 * widths
- ctr_y = boxes[:, 1] + 0.5 * heights
- wx, wy, ww, wh = self.weights
- dx = rel_codes[:, 0::4] / wx
- dy = rel_codes[:, 1::4] / wy
- dw = rel_codes[:, 2::4] / ww
- dh = rel_codes[:, 3::4] / wh
- # Prevent sending too large values into Tensor.exp()
- dw = dw.clip(min_=dw.min(), max_=self.bbox_xform_clip)
- dh = dh.clip(min_=dh.min(), max_=self.bbox_xform_clip)
- pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
- pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
- pred_w = dw.exp() * widths[:, None]
- pred_h = dh.exp() * heights[:, None]
- x = pred_ctr_x - 0.5 * pred_w
- y = pred_ctr_y - 0.5 * pred_h
- w = pred_ctr_x + 0.5 * pred_w - 1
- h = pred_ctr_y + 0.5 * pred_h - 1
- pred_boxes = Tensor.stack(x, y, w, h).permute(1,2,0).reshape(rel_codes.shape[0], rel_codes.shape[1])
- return pred_boxes
- def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"):
- if nms_thresh <= 0:
- return boxlist
- mode = boxlist.mode
- boxlist = boxlist.convert("xyxy")
- boxes = boxlist.bbox
- score = boxlist.get_field(score_field)
- keep = _box_nms(boxes.numpy(), score.numpy(), nms_thresh)
- if max_proposals > 0:
- keep = keep[:max_proposals]
- boxlist = boxlist[keep]
- return boxlist.convert(mode)
- def remove_small_boxes(boxlist, min_size):
- xywh_boxes = boxlist.convert("xywh").bbox
- _, _, ws, hs = xywh_boxes.chunk(4, dim=1)
- keep = ((
- (ws >= min_size) * (hs >= min_size)
- ) > 0).reshape(-1)
- if keep.sum().numpy() == len(boxlist):
- return boxlist
- else:
- keep = keep.numpy().nonzero()[0]
- return boxlist[keep]
- class RPNPostProcessor:
- # Not used in Loss calculation
- def __init__(
- self,
- pre_nms_top_n,
- post_nms_top_n,
- nms_thresh,
- min_size,
- box_coder=None,
- fpn_post_nms_top_n=None,
- ):
- self.pre_nms_top_n = pre_nms_top_n
- self.post_nms_top_n = post_nms_top_n
- self.nms_thresh = nms_thresh
- self.min_size = min_size
- if box_coder is None:
- box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
- self.box_coder = box_coder
- if fpn_post_nms_top_n is None:
- fpn_post_nms_top_n = post_nms_top_n
- self.fpn_post_nms_top_n = fpn_post_nms_top_n
- def forward_for_single_feature_map(self, anchors, objectness, box_regression):
- device = objectness.device
- N, A, H, W = objectness.shape
- objectness = permute_and_flatten(objectness, N, A, 1, H, W).reshape(N, -1)
- objectness = objectness.sigmoid()
- box_regression = permute_and_flatten(box_regression, N, A, 4, H, W)
- num_anchors = A * H * W
- pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)
- objectness, topk_idx = topk(objectness, pre_nms_top_n, dim=1, sorted=False)
- concat_anchors = Tensor.cat(*[a.bbox for a in anchors], dim=0).reshape(N, -1, 4)
- image_shapes = [box.size for box in anchors]
- box_regression_list = []
- concat_anchors_list = []
- for batch_idx in range(N):
- box_regression_list.append(tensor_gather(box_regression[batch_idx], topk_idx[batch_idx]))
- concat_anchors_list.append(tensor_gather(concat_anchors[batch_idx], topk_idx[batch_idx]))
- box_regression = Tensor.stack(*box_regression_list)
- concat_anchors = Tensor.stack(*concat_anchors_list)
- proposals = self.box_coder.decode(
- box_regression.reshape(-1, 4), concat_anchors.reshape(-1, 4)
- )
- proposals = proposals.reshape(N, -1, 4)
- result = []
- for proposal, score, im_shape in zip(proposals, objectness, image_shapes):
- boxlist = BoxList(proposal, im_shape, mode="xyxy")
- boxlist.add_field("objectness", score)
- boxlist = boxlist.clip_to_image(remove_empty=False)
- boxlist = remove_small_boxes(boxlist, self.min_size)
- boxlist = boxlist_nms(
- boxlist,
- self.nms_thresh,
- max_proposals=self.post_nms_top_n,
- score_field="objectness",
- )
- result.append(boxlist)
- return result
- def __call__(self, anchors, objectness, box_regression):
- sampled_boxes = []
- num_levels = len(objectness)
- anchors = list(zip(*anchors))
- for a, o, b in zip(anchors, objectness, box_regression):
- sampled_boxes.append(self.forward_for_single_feature_map(a, o, b))
- boxlists = list(zip(*sampled_boxes))
- boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]
- if num_levels > 1:
- boxlists = self.select_over_all_levels(boxlists)
- return boxlists
- def select_over_all_levels(self, boxlists):
- num_images = len(boxlists)
- for i in range(num_images):
- objectness = boxlists[i].get_field("objectness")
- post_nms_top_n = min(self.fpn_post_nms_top_n, objectness.shape[0])
- _, inds_sorted = topk(objectness,
- post_nms_top_n, dim=0, sorted=False
- )
- boxlists[i] = boxlists[i][inds_sorted]
- return boxlists
- class RPN:
- def __init__(self, in_channels):
- self.anchor_generator = AnchorGenerator()
- in_channels = 256
- head = RPNHead(
- in_channels, self.anchor_generator.num_anchors_per_location()[0]
- )
- rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
- box_selector_test = RPNPostProcessor(
- pre_nms_top_n=1000,
- post_nms_top_n=1000,
- nms_thresh=0.7,
- min_size=0,
- box_coder=rpn_box_coder,
- fpn_post_nms_top_n=1000
- )
- self.head = head
- self.box_selector_test = box_selector_test
- def __call__(self, images, features, targets=None):
- objectness, rpn_box_regression = self.head(features)
- anchors = self.anchor_generator(images, features)
- boxes = self.box_selector_test(anchors, objectness, rpn_box_regression)
- return boxes, {}
- def make_conv3x3(
- in_channels,
- out_channels,
- dilation=1,
- stride=1,
- use_gn=False,
- ):
- conv = nn.Conv2d(
- in_channels,
- out_channels,
- kernel_size=3,
- stride=stride,
- padding=dilation,
- dilation=dilation,
- bias=False if use_gn else True
- )
- return conv
- class MaskRCNNFPNFeatureExtractor:
- def __init__(self):
- resolution = 14
- scales = (0.25, 0.125, 0.0625, 0.03125)
- sampling_ratio = 2
- pooler = Pooler(
- output_size=(resolution, resolution),
- scales=scales,
- sampling_ratio=sampling_ratio,
- )
- input_size = 256
- self.pooler = pooler
- use_gn = False
- layers = (256, 256, 256, 256)
- dilation = 1
- self.mask_fcn1 = make_conv3x3(input_size, layers[0], dilation=dilation, stride=1, use_gn=use_gn)
- self.mask_fcn2 = make_conv3x3(layers[0], layers[1], dilation=dilation, stride=1, use_gn=use_gn)
- self.mask_fcn3 = make_conv3x3(layers[1], layers[2], dilation=dilation, stride=1, use_gn=use_gn)
- self.mask_fcn4 = make_conv3x3(layers[2], layers[3], dilation=dilation, stride=1, use_gn=use_gn)
- self.blocks = [self.mask_fcn1, self.mask_fcn2, self.mask_fcn3, self.mask_fcn4]
- def __call__(self, x, proposals):
- x = self.pooler(x, proposals)
- for layer in self.blocks:
- if x is not None:
- x = Tensor.relu(layer(x))
- return x
- class MaskRCNNC4Predictor:
- def __init__(self):
- num_classes = 81
- dim_reduced = 256
- num_inputs = dim_reduced
- self.conv5_mask = nn.ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0)
- self.mask_fcn_logits = nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)
- def __call__(self, x):
- x = Tensor.relu(self.conv5_mask(x))
- return self.mask_fcn_logits(x)
- class FPN2MLPFeatureExtractor:
- def __init__(self, cfg):
- resolution = 7
- scales = (0.25, 0.125, 0.0625, 0.03125)
- sampling_ratio = 2
- pooler = Pooler(
- output_size=(resolution, resolution),
- scales=scales,
- sampling_ratio=sampling_ratio,
- )
- input_size = 256 * resolution ** 2
- representation_size = 1024
- self.pooler = pooler
- self.fc6 = nn.Linear(input_size, representation_size)
- self.fc7 = nn.Linear(representation_size, representation_size)
- def __call__(self, x, proposals):
- x = self.pooler(x, proposals)
- x = x.reshape(x.shape[0], -1)
- x = Tensor.relu(self.fc6(x))
- x = Tensor.relu(self.fc7(x))
- return x
- def _bilinear_interpolate(
- input, # [N, C, H, W]
- roi_batch_ind, # [K]
- y, # [K, PH, IY]
- x, # [K, PW, IX]
- ymask, # [K, IY]
- xmask, # [K, IX]
- ):
- _, channels, height, width = input.shape
- y = y.clip(min_=0.0, max_=float(height-1))
- x = x.clip(min_=0.0, max_=float(width-1))
- # Tensor.where doesnt work well with int32 data so cast to float32
- y_low = y.cast(dtypes.int32).contiguous().float().contiguous()
- x_low = x.cast(dtypes.int32).contiguous().float().contiguous()
- y_high = Tensor.where(y_low >= height - 1, float(height - 1), y_low + 1)
- y_low = Tensor.where(y_low >= height - 1, float(height - 1), y_low)
- x_high = Tensor.where(x_low >= width - 1, float(width - 1), x_low + 1)
- x_low = Tensor.where(x_low >= width - 1, float(width - 1), x_low)
- ly = y - y_low
- lx = x - x_low
- hy = 1.0 - ly
- hx = 1.0 - lx
- def masked_index(
- y, # [K, PH, IY]
- x, # [K, PW, IX]
- ):
- if ymask is not None:
- assert xmask is not None
- y = Tensor.where(ymask[:, None, :], y, 0)
- x = Tensor.where(xmask[:, None, :], x, 0)
- key1 = roi_batch_ind[:, None, None, None, None, None]
- key2 = Tensor.arange(channels, device=input.device)[None, :, None, None, None, None]
- key3 = y[:, None, :, None, :, None]
- key4 = x[:, None, None, :, None, :]
- return tensor_getitem(input,key1,key2,key3,key4) # [K, C, PH, PW, IY, IX]
- v1 = masked_index(y_low, x_low)
- v2 = masked_index(y_low, x_high)
- v3 = masked_index(y_high, x_low)
- v4 = masked_index(y_high, x_high)
- # all ws preemptively [K, C, PH, PW, IY, IX]
- def outer_prod(y, x):
- return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]
- w1 = outer_prod(hy, hx)
- w2 = outer_prod(hy, lx)
- w3 = outer_prod(ly, hx)
- w4 = outer_prod(ly, lx)
- val = w1*v1 + w2*v2 + w3*v3 + w4*v4
- return val
- #https://pytorch.org/vision/main/_modules/torchvision/ops/roi_align.html#roi_align
- def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
- orig_dtype = input.dtype
- _, _, height, width = input.shape
- ph = Tensor.arange(pooled_height, device=input.device)
- pw = Tensor.arange(pooled_width, device=input.device)
- roi_batch_ind = rois[:, 0].cast(dtypes.int32).contiguous()
- offset = 0.5 if aligned else 0.0
- roi_start_w = rois[:, 1] * spatial_scale - offset
- roi_start_h = rois[:, 2] * spatial_scale - offset
- roi_end_w = rois[:, 3] * spatial_scale - offset
- roi_end_h = rois[:, 4] * spatial_scale - offset
- roi_width = roi_end_w - roi_start_w
- roi_height = roi_end_h - roi_start_h
- if not aligned:
- roi_width = roi_width.maximum(1.0)
- roi_height = roi_height.maximum(1.0)
- bin_size_h = roi_height / pooled_height
- bin_size_w = roi_width / pooled_width
- exact_sampling = sampling_ratio > 0
- roi_bin_grid_h = sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil()
- roi_bin_grid_w = sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil()
- if exact_sampling:
- count = max(roi_bin_grid_h * roi_bin_grid_w, 1)
- iy = Tensor.arange(roi_bin_grid_h, device=input.device)
- ix = Tensor.arange(roi_bin_grid_w, device=input.device)
- ymask = None
- xmask = None
- else:
- count = (roi_bin_grid_h * roi_bin_grid_w).maximum(1)
- iy = Tensor.arange(height, device=input.device)
- ix = Tensor.arange(width, device=input.device)
- ymask = iy[None, :] < roi_bin_grid_h[:, None]
- xmask = ix[None, :] < roi_bin_grid_w[:, None]
- def from_K(t):
- return t[:, None, None]
- y = (
- from_K(roi_start_h)
- + ph[None, :, None] * from_K(bin_size_h)
- + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h)
- )
- x = (
- from_K(roi_start_w)
- + pw[None, :, None] * from_K(bin_size_w)
- + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w)
- )
- val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)
- if not exact_sampling:
- val = ymask[:, None, None, None, :, None].where(val, 0)
- val = xmask[:, None, None, None, None, :].where(val, 0)
- output = val.sum((-1, -2))
- if isinstance(count, Tensor):
- output /= count[:, None, None, None]
- else:
- output /= count
- output = output.cast(orig_dtype)
- return output
- class ROIAlign:
- def __init__(self, output_size, spatial_scale, sampling_ratio):
- self.output_size = output_size
- self.spatial_scale = spatial_scale
- self.sampling_ratio = sampling_ratio
- def __call__(self, input, rois):
- output = _roi_align(
- input, rois, self.spatial_scale, self.output_size[0], self.output_size[1], self.sampling_ratio, aligned=False
- )
- return output
- class LevelMapper:
- def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
- self.k_min = k_min
- self.k_max = k_max
- self.s0 = canonical_scale
- self.lvl0 = canonical_level
- self.eps = eps
- def __call__(self, boxlists):
- s = Tensor.sqrt(Tensor.cat(*[boxlist.area() for boxlist in boxlists]))
- target_lvls = (self.lvl0 + Tensor.log2(s / self.s0 + self.eps)).floor()
- target_lvls = target_lvls.clip(min_=self.k_min, max_=self.k_max)
- return target_lvls - self.k_min
- class Pooler:
- def __init__(self, output_size, scales, sampling_ratio):
- self.output_size = output_size
- self.scales = scales
- self.sampling_ratio = sampling_ratio
- poolers = []
- for scale in scales:
- poolers.append(
- ROIAlign(
- output_size, spatial_scale=scale, sampling_ratio=sampling_ratio
- )
- )
- self.poolers = poolers
- self.output_size = output_size
- lvl_min = -math.log2(scales[0])
- lvl_max = -math.log2(scales[-1])
- self.map_levels = LevelMapper(lvl_min, lvl_max)
- def convert_to_roi_format(self, boxes):
- concat_boxes = Tensor.cat(*[b.bbox for b in boxes], dim=0)
- device, dtype = concat_boxes.device, concat_boxes.dtype
- ids = Tensor.cat(
- *[
- Tensor.full((len(b), 1), i, dtype=dtype, device=device)
- for i, b in enumerate(boxes)
- ],
- dim=0,
- )
- if concat_boxes.shape[0] != 0:
- rois = Tensor.cat(*[ids, concat_boxes], dim=1)
- return rois
- def __call__(self, x, boxes):
- num_levels = len(self.poolers)
- rois = self.convert_to_roi_format(boxes)
- if rois is not None:
- if num_levels == 1:
- return self.poolers[0](x[0], rois)
- levels = self.map_levels(boxes)
- results = []
- all_idxs = []
- for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
- # this is fine because no grad will flow through index
- idx_in_level = (levels.numpy() == level).nonzero()[0]
- if len(idx_in_level) > 0:
- rois_per_level = tensor_gather(rois, idx_in_level)
- pooler_output = pooler(per_level_feature, rois_per_level)
- all_idxs.extend(idx_in_level)
- results.append(pooler_output)
- return tensor_gather(Tensor.cat(*results), [x[0] for x in sorted({i:idx for i, idx in enumerate(all_idxs)}.items(), key=lambda x: x[1])])
- class FPNPredictor:
- def __init__(self):
- num_classes = 81
- representation_size = 1024
- self.cls_score = nn.Linear(representation_size, num_classes)
- num_bbox_reg_classes = num_classes
- self.bbox_pred = nn.Linear(representation_size, num_bbox_reg_classes * 4)
- def __call__(self, x):
- scores = self.cls_score(x)
- bbox_deltas = self.bbox_pred(x)
- return scores, bbox_deltas
- class PostProcessor:
- # Not used in training
- def __init__(
- self,
- score_thresh=0.05,
- nms=0.5,
- detections_per_img=100,
- box_coder=None,
- cls_agnostic_bbox_reg=False
- ):
- self.score_thresh = score_thresh
- self.nms = nms
- self.detections_per_img = detections_per_img
- if box_coder is None:
- box_coder = BoxCoder(weights=(10., 10., 5., 5.))
- self.box_coder = box_coder
- self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg
- def __call__(self, x, boxes):
- class_logits, box_regression = x
- class_prob = Tensor.softmax(class_logits, -1)
- image_shapes = [box.size for box in boxes]
- boxes_per_image = [len(box) for box in boxes]
- concat_boxes = Tensor.cat(*[a.bbox for a in boxes], dim=0)
- if self.cls_agnostic_bbox_reg:
- box_regression = box_regression[:, -4:]
- proposals = self.box_coder.decode(
- box_regression.reshape(sum(boxes_per_image), -1), concat_boxes
- )
- if self.cls_agnostic_bbox_reg:
- proposals = proposals.repeat([1, class_prob.shape[1]])
- num_classes = class_prob.shape[1]
- proposals = proposals.unsqueeze(0)
- class_prob = class_prob.unsqueeze(0)
- results = []
- for prob, boxes_per_img, image_shape in zip(
- class_prob, proposals, image_shapes
- ):
- boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape)
- boxlist = boxlist.clip_to_image(remove_empty=False)
- boxlist = self.filter_results(boxlist, num_classes)
- results.append(boxlist)
- return results
- def prepare_boxlist(self, boxes, scores, image_shape):
- boxes = boxes.reshape(-1, 4)
- scores = scores.reshape(-1)
- boxlist = BoxList(boxes, image_shape, mode="xyxy")
- boxlist.add_field("scores", scores)
- return boxlist
- def filter_results(self, boxlist, num_classes):
- boxes = boxlist.bbox.reshape(-1, num_classes * 4)
- scores = boxlist.get_field("scores").reshape(-1, num_classes)
- device = scores.device
- result = []
- scores = scores.numpy()
- boxes = boxes.numpy()
- inds_all = scores > self.score_thresh
- for j in range(1, num_classes):
- inds = inds_all[:, j].nonzero()[0]
- # This needs to be done in numpy because it can create empty arrays
- scores_j = scores[inds, j]
- boxes_j = boxes[inds, j * 4: (j + 1) * 4]
- boxes_j = Tensor(boxes_j)
- scores_j = Tensor(scores_j)
- boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
- boxlist_for_class.add_field("scores", scores_j)
- if len(boxlist_for_class):
- boxlist_for_class = boxlist_nms(
- boxlist_for_class, self.nms
- )
- num_labels = len(boxlist_for_class)
- boxlist_for_class.add_field(
- "labels", Tensor.full((num_labels,), j, device=device)
- )
- result.append(boxlist_for_class)
- result = cat_boxlist(result)
- number_of_detections = len(result)
- if number_of_detections > self.detections_per_img > 0:
- cls_scores = result.get_field("scores")
- image_thresh, _ = topk(cls_scores, k=self.detections_per_img)
- image_thresh = image_thresh.numpy()[-1]
- keep = (cls_scores.numpy() >= image_thresh).nonzero()[0]
- result = result[keep]
- return result
- class RoIBoxHead:
- def __init__(self, in_channels):
- self.feature_extractor = FPN2MLPFeatureExtractor(in_channels)
- self.predictor = FPNPredictor()
- self.post_processor = PostProcessor(
- score_thresh=0.05,
- nms=0.5,
- detections_per_img=100,
- box_coder=BoxCoder(weights=(10., 10., 5., 5.)),
- cls_agnostic_bbox_reg=False
- )
- def __call__(self, features, proposals, targets=None):
- x = self.feature_extractor(features, proposals)
- class_logits, box_regression = self.predictor(x)
- if not Tensor.training:
- result = self.post_processor((class_logits, box_regression), proposals)
- return x, result, {}
- class MaskPostProcessor:
- # Not used in loss calculation
- def __call__(self, x, boxes):
- mask_prob = x.sigmoid().numpy()
- num_masks = x.shape[0]
- labels = [bbox.get_field("labels") for bbox in boxes]
- labels = Tensor.cat(*labels).numpy().astype(np.int32)
- index = np.arange(num_masks)
- mask_prob = mask_prob[index, labels][:, None]
- boxes_per_image, cumsum = [], 0
- for box in boxes:
- cumsum += len(box)
- boxes_per_image.append(cumsum)
- # using numpy here as Tensor.chunk doesnt have custom chunk sizes
- mask_prob = np.split(mask_prob, boxes_per_image, axis=0)
- results = []
- for prob, box in zip(mask_prob, boxes):
- bbox = BoxList(box.bbox, box.size, mode="xyxy")
- for field in box.fields():
- bbox.add_field(field, box.get_field(field))
- prob = Tensor(prob)
- bbox.add_field("mask", prob)
- results.append(bbox)
- return results
- class Mask:
- def __init__(self):
- self.feature_extractor = MaskRCNNFPNFeatureExtractor()
- self.predictor = MaskRCNNC4Predictor()
- self.post_processor = MaskPostProcessor()
- def __call__(self, features, proposals, targets=None):
- x = self.feature_extractor(features, proposals)
- if x:
- mask_logits = self.predictor(x)
- if not Tensor.training:
- result = self.post_processor(mask_logits, proposals)
- return x, result, {}
- return x, [], {}
- class RoIHeads:
- def __init__(self, in_channels):
- self.box = RoIBoxHead(in_channels)
- self.mask = Mask()
- def __call__(self, features, proposals, targets=None):
- x, detections, _ = self.box(features, proposals, targets)
- x, detections, _ = self.mask(features, detections, targets)
- return x, detections, {}
- class ImageList(object):
- def __init__(self, tensors, image_sizes):
- self.tensors = tensors
- self.image_sizes = image_sizes
- def to(self, *args, **kwargs):
- cast_tensor = self.tensors.to(*args, **kwargs)
- return ImageList(cast_tensor, self.image_sizes)
- def to_image_list(tensors, size_divisible=32):
- # Preprocessing
- if isinstance(tensors, Tensor) and size_divisible > 0:
- tensors = [tensors]
- if isinstance(tensors, ImageList):
- return tensors
- elif isinstance(tensors, Tensor):
- # single tensor shape can be inferred
- assert tensors.ndim == 4
- image_sizes = [tensor.shape[-2:] for tensor in tensors]
- return ImageList(tensors, image_sizes)
- elif isinstance(tensors, (tuple, list)):
- max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))
- if size_divisible > 0:
- stride = size_divisible
- max_size = list(max_size)
- max_size[1] = int(math.ceil(max_size[1] / stride) * stride)
- max_size[2] = int(math.ceil(max_size[2] / stride) * stride)
- max_size = tuple(max_size)
- batch_shape = (len(tensors),) + max_size
- batched_imgs = np.zeros(batch_shape, dtype=_to_np_dtype(tensors[0].dtype))
- for img, pad_img in zip(tensors, batched_imgs):
- pad_img[: img.shape[0], : img.shape[1], : img.shape[2]] += img.numpy()
- batched_imgs = Tensor(batched_imgs)
- image_sizes = [im.shape[-2:] for im in tensors]
- return ImageList(batched_imgs, image_sizes)
- else:
- raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors)))
- class MaskRCNN:
- def __init__(self, backbone: ResNet):
- self.backbone = ResNetFPN(backbone, out_channels=256)
- self.rpn = RPN(self.backbone.out_channels)
- self.roi_heads = RoIHeads(self.backbone.out_channels)
- def load_from_pretrained(self):
- fn = Path('./') / "weights/maskrcnn.pt"
- fetch("https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_FPN_1x.pth", fn)
- state_dict = torch_load(fn)['model']
- loaded_keys = []
- for k, v in state_dict.items():
- if "module." in k:
- k = k.replace("module.", "")
- if "stem." in k:
- k = k.replace("stem.", "")
- if "fpn_inner" in k:
- block_index = int(re.search(r"fpn_inner(\d+)", k).group(1))
- k = re.sub(r"fpn_inner\d+", f"inner_blocks.{block_index - 1}", k)
- if "fpn_layer" in k:
- block_index = int(re.search(r"fpn_layer(\d+)", k).group(1))
- k = re.sub(r"fpn_layer\d+", f"layer_blocks.{block_index - 1}", k)
- loaded_keys.append(k)
- get_child(self, k).assign(v.numpy()).realize()
- return loaded_keys
- def __call__(self, images):
- images = to_image_list(images)
- features = self.backbone(images.tensors)
- proposals, _ = self.rpn(images, features)
- x, result, _ = self.roi_heads(features, proposals)
- return result
- if __name__ == '__main__':
- resnet = resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
- model = MaskRCNN(backbone=resnet)
- model.load_from_pretrained()
|