# resnet.py
import tinygrad.nn as nn
from tinygrad import Tensor, dtypes
from tinygrad.nn.state import torch_load
from tinygrad.helpers import fetch, get_child

# allow monkeypatching in layer implementations
# (e.g. swapping in a quantized or fused variant without editing the classes below)
BatchNorm = nn.BatchNorm2d
Conv2d = nn.Conv2d
Linear = nn.Linear
  9. class BasicBlock:
  10. expansion = 1
  11. def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64):
  12. assert groups == 1 and base_width == 64, "BasicBlock only supports groups=1 and base_width=64"
  13. self.conv1 = Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
  14. self.bn1 = BatchNorm(planes)
  15. self.conv2 = Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
  16. self.bn2 = BatchNorm(planes)
  17. self.downsample = []
  18. if stride != 1 or in_planes != self.expansion*planes:
  19. self.downsample = [
  20. Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
  21. BatchNorm(self.expansion*planes)
  22. ]
  23. def __call__(self, x):
  24. out = self.bn1(self.conv1(x)).relu()
  25. out = self.bn2(self.conv2(out))
  26. out = out + x.sequential(self.downsample)
  27. out = out.relu()
  28. return out
  29. class Bottleneck:
  30. # NOTE: stride_in_1x1=False, this is the v1.5 variant
  31. expansion = 4
  32. def __init__(self, in_planes, planes, stride=1, stride_in_1x1=False, groups=1, base_width=64):
  33. width = int(planes * (base_width / 64.0)) * groups
  34. # NOTE: the original implementation places stride at the first convolution (self.conv1), control with stride_in_1x1
  35. self.conv1 = Conv2d(in_planes, width, kernel_size=1, stride=stride if stride_in_1x1 else 1, bias=False)
  36. self.bn1 = BatchNorm(width)
  37. self.conv2 = Conv2d(width, width, kernel_size=3, padding=1, stride=1 if stride_in_1x1 else stride, groups=groups, bias=False)
  38. self.bn2 = BatchNorm(width)
  39. self.conv3 = Conv2d(width, self.expansion*planes, kernel_size=1, bias=False)
  40. self.bn3 = BatchNorm(self.expansion*planes)
  41. self.downsample = []
  42. if stride != 1 or in_planes != self.expansion*planes:
  43. self.downsample = [
  44. Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
  45. BatchNorm(self.expansion*planes)
  46. ]
  47. def __call__(self, x):
  48. out = self.bn1(self.conv1(x)).relu()
  49. out = self.bn2(self.conv2(out)).relu()
  50. out = self.bn3(self.conv3(out))
  51. out = out + x.sequential(self.downsample)
  52. out = out.relu()
  53. return out
  54. class ResNet:
  55. def __init__(self, num, num_classes=None, groups=1, width_per_group=64, stride_in_1x1=False):
  56. self.num = num
  57. self.block = {
  58. 18: BasicBlock,
  59. 34: BasicBlock,
  60. 50: Bottleneck,
  61. 101: Bottleneck,
  62. 152: Bottleneck
  63. }[num]
  64. self.num_blocks = {
  65. 18: [2,2,2,2],
  66. 34: [3,4,6,3],
  67. 50: [3,4,6,3],
  68. 101: [3,4,23,3],
  69. 152: [3,8,36,3]
  70. }[num]
  71. self.in_planes = 64
  72. self.groups = groups
  73. self.base_width = width_per_group
  74. self.conv1 = Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3)
  75. self.bn1 = BatchNorm(64)
  76. self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1, stride_in_1x1=stride_in_1x1)
  77. self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2, stride_in_1x1=stride_in_1x1)
  78. self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2, stride_in_1x1=stride_in_1x1)
  79. self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2, stride_in_1x1=stride_in_1x1)
  80. self.fc = Linear(512 * self.block.expansion, num_classes) if num_classes is not None else None
  81. def _make_layer(self, block, planes, num_blocks, stride, stride_in_1x1):
  82. strides = [stride] + [1] * (num_blocks-1)
  83. layers = []
  84. for stride in strides:
  85. if block == Bottleneck:
  86. layers.append(block(self.in_planes, planes, stride, stride_in_1x1, self.groups, self.base_width))
  87. else:
  88. layers.append(block(self.in_planes, planes, stride, self.groups, self.base_width))
  89. self.in_planes = planes * block.expansion
  90. return layers
  91. def forward(self, x):
  92. is_feature_only = self.fc is None
  93. if is_feature_only: features = []
  94. out = self.bn1(self.conv1(x)).relu()
  95. out = out.pad2d([1,1,1,1]).max_pool2d((3,3), 2)
  96. out = out.sequential(self.layer1)
  97. if is_feature_only: features.append(out)
  98. out = out.sequential(self.layer2)
  99. if is_feature_only: features.append(out)
  100. out = out.sequential(self.layer3)
  101. if is_feature_only: features.append(out)
  102. out = out.sequential(self.layer4)
  103. if is_feature_only: features.append(out)
  104. if not is_feature_only:
  105. out = out.mean([2,3])
  106. out = self.fc(out.cast(dtypes.float32))
  107. return out
  108. return features
  109. def __call__(self, x:Tensor) -> Tensor:
  110. return self.forward(x)
  111. def load_from_pretrained(self):
  112. # TODO replace with fake torch load
  113. model_urls = {
  114. (18, 1, 64): 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
  115. (34, 1, 64): 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
  116. (50, 1, 64): 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
  117. (50, 32, 4): 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
  118. (101, 1, 64): 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
  119. (152, 1, 64): 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
  120. }
  121. self.url = model_urls[(self.num, self.groups, self.base_width)]
  122. for k, v in torch_load(fetch(self.url)).items():
  123. obj: Tensor = get_child(self, k)
  124. dat = v.numpy()
  125. if 'fc.' in k and obj.shape != dat.shape:
  126. print("skipping fully connected layer")
  127. continue # Skip FC if transfer learning
  128. if 'bn' not in k and 'downsample' not in k: assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
  129. obj.assign(dat.reshape(obj.shape))
  130. ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
  131. ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
  132. ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes)
  133. ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes)
  134. ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes)
  135. ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4)