# efficientnet.py
import math

from tinygrad.tensor import Tensor
from tinygrad.nn import BatchNorm2d
from tinygrad.nn.state import torch_load
from tinygrad.helpers import get_child, fetch
  6. class MBConvBlock:
  7. def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se, track_running_stats=True):
  8. oup = expand_ratio * input_filters
  9. if expand_ratio != 1:
  10. self._expand_conv = Tensor.glorot_uniform(oup, input_filters, 1, 1)
  11. self._bn0 = BatchNorm2d(oup, track_running_stats=track_running_stats)
  12. else:
  13. self._expand_conv = None
  14. self.strides = strides
  15. if strides == (2,2):
  16. self.pad = [(kernel_size-1)//2-1, (kernel_size-1)//2]*2
  17. else:
  18. self.pad = [(kernel_size-1)//2]*4
  19. self._depthwise_conv = Tensor.glorot_uniform(oup, 1, kernel_size, kernel_size)
  20. self._bn1 = BatchNorm2d(oup, track_running_stats=track_running_stats)
  21. self.has_se = has_se
  22. if self.has_se:
  23. num_squeezed_channels = max(1, int(input_filters * se_ratio))
  24. self._se_reduce = Tensor.glorot_uniform(num_squeezed_channels, oup, 1, 1)
  25. self._se_reduce_bias = Tensor.zeros(num_squeezed_channels)
  26. self._se_expand = Tensor.glorot_uniform(oup, num_squeezed_channels, 1, 1)
  27. self._se_expand_bias = Tensor.zeros(oup)
  28. self._project_conv = Tensor.glorot_uniform(output_filters, oup, 1, 1)
  29. self._bn2 = BatchNorm2d(output_filters, track_running_stats=track_running_stats)
  30. def __call__(self, inputs):
  31. x = inputs
  32. if self._expand_conv is not None:
  33. x = self._bn0(x.conv2d(self._expand_conv)).swish()
  34. x = x.conv2d(self._depthwise_conv, padding=self.pad, stride=self.strides, groups=self._depthwise_conv.shape[0])
  35. x = self._bn1(x).swish()
  36. if self.has_se:
  37. x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4])
  38. x_squeezed = x_squeezed.conv2d(self._se_reduce, self._se_reduce_bias).swish()
  39. x_squeezed = x_squeezed.conv2d(self._se_expand, self._se_expand_bias)
  40. x = x.mul(x_squeezed.sigmoid())
  41. x = self._bn2(x.conv2d(self._project_conv))
  42. if x.shape == inputs.shape:
  43. x = x.add(inputs)
  44. return x
  45. class EfficientNet:
  46. def __init__(self, number=0, classes=1000, has_se=True, track_running_stats=True, input_channels=3, has_fc_output=True):
  47. self.number = number
  48. global_params = [
  49. # width, depth
  50. (1.0, 1.0), # b0
  51. (1.0, 1.1), # b1
  52. (1.1, 1.2), # b2
  53. (1.2, 1.4), # b3
  54. (1.4, 1.8), # b4
  55. (1.6, 2.2), # b5
  56. (1.8, 2.6), # b6
  57. (2.0, 3.1), # b7
  58. (2.2, 3.6), # b8
  59. (4.3, 5.3), # l2
  60. ][max(number,0)]
  61. def round_filters(filters):
  62. multiplier = global_params[0]
  63. divisor = 8
  64. filters *= multiplier
  65. new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
  66. if new_filters < 0.9 * filters: # prevent rounding by more than 10%
  67. new_filters += divisor
  68. return int(new_filters)
  69. def round_repeats(repeats):
  70. return int(math.ceil(global_params[1] * repeats))
  71. out_channels = round_filters(32)
  72. self._conv_stem = Tensor.glorot_uniform(out_channels, input_channels, 3, 3)
  73. self._bn0 = BatchNorm2d(out_channels, track_running_stats=track_running_stats)
  74. blocks_args = [
  75. [1, 3, (1,1), 1, 32, 16, 0.25],
  76. [2, 3, (2,2), 6, 16, 24, 0.25],
  77. [2, 5, (2,2), 6, 24, 40, 0.25],
  78. [3, 3, (2,2), 6, 40, 80, 0.25],
  79. [3, 5, (1,1), 6, 80, 112, 0.25],
  80. [4, 5, (2,2), 6, 112, 192, 0.25],
  81. [1, 3, (1,1), 6, 192, 320, 0.25],
  82. ]
  83. if self.number == -1:
  84. blocks_args = [
  85. [1, 3, (2,2), 1, 32, 40, 0.25],
  86. [1, 3, (2,2), 1, 40, 80, 0.25],
  87. [1, 3, (2,2), 1, 80, 192, 0.25],
  88. [1, 3, (2,2), 1, 192, 320, 0.25],
  89. ]
  90. elif self.number == -2:
  91. blocks_args = [
  92. [1, 9, (8,8), 1, 32, 320, 0.25],
  93. ]
  94. self._blocks = []
  95. for num_repeats, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio in blocks_args:
  96. input_filters, output_filters = round_filters(input_filters), round_filters(output_filters)
  97. for n in range(round_repeats(num_repeats)):
  98. self._blocks.append(MBConvBlock(kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se=has_se, track_running_stats=track_running_stats))
  99. input_filters = output_filters
  100. strides = (1,1)
  101. in_channels = round_filters(320)
  102. out_channels = round_filters(1280)
  103. self._conv_head = Tensor.glorot_uniform(out_channels, in_channels, 1, 1)
  104. self._bn1 = BatchNorm2d(out_channels, track_running_stats=track_running_stats)
  105. if has_fc_output:
  106. self._fc = Tensor.glorot_uniform(out_channels, classes)
  107. self._fc_bias = Tensor.zeros(classes)
  108. else:
  109. self._fc = None
  110. def forward(self, x):
  111. x = self._bn0(x.conv2d(self._conv_stem, padding=(0,1,0,1), stride=2)).swish()
  112. x = x.sequential(self._blocks)
  113. x = self._bn1(x.conv2d(self._conv_head)).swish()
  114. x = x.avg_pool2d(kernel_size=x.shape[2:4])
  115. x = x.reshape(shape=(-1, x.shape[1]))
  116. return x.linear(self._fc, self._fc_bias) if self._fc is not None else x
  117. def load_from_pretrained(self):
  118. model_urls = {
  119. 0: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
  120. 1: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth",
  121. 2: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
  122. 3: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth",
  123. 4: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
  124. 5: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
  125. 6: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
  126. 7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth"
  127. }
  128. b0 = torch_load(fetch(model_urls[self.number]))
  129. for k,v in b0.items():
  130. if k.endswith("num_batches_tracked"): continue
  131. for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']:
  132. if cat in k:
  133. k = k.replace('.bias', '_bias')
  134. k = k.replace('.weight', '')
  135. #print(k, v.shape)
  136. mv:Tensor = get_child(self, k)
  137. vnp = v #.astype(np.float32)
  138. vnp = vnp if k != '_fc' else vnp.clang().T
  139. #vnp = vnp if vnp.shape != () else np.array([vnp])
  140. if mv.shape == vnp.shape:
  141. mv.replace(vnp.to(mv.device))
  142. else:
  143. print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape))