inception.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. from tinygrad import Tensor
  2. from tinygrad.nn import Conv2d, BatchNorm2d, Linear
  3. from tinygrad.nn.state import load_state_dict, torch_load
  4. from tinygrad.helpers import fetch
  5. from typing import Optional, Dict
  6. # Base Inception Model
  7. class BasicConv2d:
  8. def __init__(self, in_ch:int, out_ch:int, **kwargs):
  9. self.conv = Conv2d(in_ch, out_ch, bias=False, **kwargs)
  10. self.bn = BatchNorm2d(out_ch, eps=0.001)
  11. def __call__(self, x:Tensor) -> Tensor:
  12. return x.sequential([self.conv, self.bn, Tensor.relu])
  13. class InceptionA:
  14. def __init__(self, in_ch:int, pool_feat:int):
  15. self.branch1x1 = BasicConv2d(in_ch, 64, kernel_size=1)
  16. self.branch5x5_1 = BasicConv2d(in_ch, 48, kernel_size=1)
  17. self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)
  18. self.branch3x3dbl_1 = BasicConv2d(in_ch, 64, kernel_size=1)
  19. self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=(3,3), padding=1)
  20. self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=(3,3), padding=1)
  21. self.branch_pool = BasicConv2d(in_ch, pool_feat, kernel_size=1)
  22. def __call__(self, x:Tensor) -> Tensor:
  23. outputs = [
  24. self.branch1x1(x),
  25. x.sequential([self.branch5x5_1, self.branch5x5_2]),
  26. x.sequential([self.branch3x3dbl_1, self.branch3x3dbl_2, self.branch3x3dbl_3]),
  27. self.branch_pool(x.avg_pool2d(kernel_size=(3,3), stride=1, padding=1)),
  28. ]
  29. return Tensor.cat(*outputs, dim=1)
  30. class InceptionB:
  31. def __init__(self, in_ch:int):
  32. self.branch3x3 = BasicConv2d(in_ch, 384, kernel_size=(3,3), stride=2)
  33. self.branch3x3dbl_1 = BasicConv2d(in_ch, 64, kernel_size=1)
  34. self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=(3,3), padding=1)
  35. self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=(3,3), stride=2)
  36. def __call__(self, x:Tensor) -> Tensor:
  37. outputs = [
  38. self.branch3x3(x),
  39. x.sequential([self.branch3x3dbl_1, self.branch3x3dbl_2, self.branch3x3dbl_3]),
  40. x.max_pool2d(kernel_size=(3,3), stride=2, dilation=1),
  41. ]
  42. return Tensor.cat(*outputs, dim=1)
  43. class InceptionC:
  44. def __init__(self, in_ch, ch_7x7):
  45. self.branch1x1 = BasicConv2d(in_ch, 192, kernel_size=1)
  46. self.branch7x7_1 = BasicConv2d(in_ch, ch_7x7, kernel_size=1)
  47. self.branch7x7_2 = BasicConv2d(ch_7x7, ch_7x7, kernel_size=(1, 7), padding=(0, 3))
  48. self.branch7x7_3 = BasicConv2d(ch_7x7, 192, kernel_size=(7, 1), padding=(3, 0))
  49. self.branch7x7dbl_1 = BasicConv2d(in_ch, ch_7x7, kernel_size=1)
  50. self.branch7x7dbl_2 = BasicConv2d(ch_7x7, ch_7x7, kernel_size=(7, 1), padding=(3, 0))
  51. self.branch7x7dbl_3 = BasicConv2d(ch_7x7, ch_7x7, kernel_size=(1, 7), padding=(0, 3))
  52. self.branch7x7dbl_4 = BasicConv2d(ch_7x7, ch_7x7, kernel_size=(7, 1), padding=(3, 0))
  53. self.branch7x7dbl_5 = BasicConv2d(ch_7x7, 192, kernel_size=(1, 7), padding=(0, 3))
  54. self.branch_pool = BasicConv2d(in_ch, 192, kernel_size=1)
  55. def __call__(self, x:Tensor) -> Tensor:
  56. outputs = [
  57. self.branch1x1(x),
  58. x.sequential([self.branch7x7_1, self.branch7x7_2, self.branch7x7_3]),
  59. x.sequential([self.branch7x7dbl_1, self.branch7x7dbl_2, self.branch7x7dbl_3, self.branch7x7dbl_4, self.branch7x7dbl_5]),
  60. self.branch_pool(x.avg_pool2d(kernel_size=(3,3), stride=1, padding=1)),
  61. ]
  62. return Tensor.cat(*outputs, dim=1)
  63. class InceptionD:
  64. def __init__(self, in_ch:int):
  65. self.branch3x3_1 = BasicConv2d(in_ch, 192, kernel_size=1)
  66. self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=(3,3), stride=2)
  67. self.branch7x7x3_1 = BasicConv2d(in_ch, 192, kernel_size=1)
  68. self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3))
  69. self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0))
  70. self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=(3,3), stride=2)
  71. def __call__(self, x:Tensor) -> Tensor:
  72. outputs = [
  73. x.sequential([self.branch3x3_1, self.branch3x3_2]),
  74. x.sequential([self.branch7x7x3_1, self.branch7x7x3_2, self.branch7x7x3_3, self.branch7x7x3_4]),
  75. x.max_pool2d(kernel_size=(3,3), stride=2, dilation=1),
  76. ]
  77. return Tensor.cat(*outputs, dim=1)
  78. class InceptionE:
  79. def __init__(self, in_ch:int):
  80. self.branch1x1 = BasicConv2d(in_ch, 320, kernel_size=1)
  81. self.branch3x3_1 = BasicConv2d(in_ch, 384, kernel_size=1)
  82. self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
  83. self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
  84. self.branch3x3dbl_1 = BasicConv2d(in_ch, 448, kernel_size=1)
  85. self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=(3,3), padding=1)
  86. self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
  87. self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
  88. self.branch_pool = BasicConv2d(in_ch, 192, kernel_size=1)
  89. def __call__(self, x:Tensor) -> Tensor:
  90. branch3x3 = self.branch3x3_1(x)
  91. branch3x3dbl = x.sequential([self.branch3x3dbl_1, self.branch3x3dbl_2])
  92. outputs = [
  93. self.branch1x1(x),
  94. Tensor.cat(self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3), dim=1),
  95. Tensor.cat(self.branch3x3dbl_3a(branch3x3dbl), self.branch3x3dbl_3b(branch3x3dbl), dim=1),
  96. self.branch_pool(x.avg_pool2d(kernel_size=(3,3), stride=1, padding=1)),
  97. ]
  98. return Tensor.cat(*outputs, dim=1)
  99. class InceptionAux:
  100. def __init__(self, in_ch:int, num_classes:int):
  101. self.conv0 = BasicConv2d(in_ch, 128, kernel_size=1)
  102. self.conv1 = BasicConv2d(128, 768, kernel_size=5)
  103. self.fc = Linear(768, num_classes)
  104. def __call__(self, x:Tensor) -> Tensor:
  105. x = x.avg_pool2d(kernel_size=5, stride=3, padding=1).sequential([self.conv0, self.conv1])
  106. x = x.avg_pool2d(kernel_size=1, padding=1).reshape(x.shape[0],-1)
  107. return self.fc(x)
  108. class Inception3:
  109. def __init__(self, num_classes:int=1008, cls_map:Optional[Dict]=None):
  110. def get_cls(key1:str, key2:str, default):
  111. return default if cls_map is None else cls_map.get(key1, cls_map.get(key2, default))
  112. self.transform_input = False
  113. self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=(3,3), stride=2)
  114. self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=(3,3))
  115. self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=(3,3), padding=1)
  116. self.maxpool1 = lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=2, padding=1)
  117. self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
  118. self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=(3,3))
  119. self.maxpool2 = lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=2, padding=1)
  120. self.Mixed_5b = get_cls("A1","A",InceptionA)(192, pool_feat=32)
  121. self.Mixed_5c = get_cls("A2","A",InceptionA)(256, pool_feat=64)
  122. self.Mixed_5d = get_cls("A3","A",InceptionA)(288, pool_feat=64)
  123. self.Mixed_6a = get_cls("B1","B",InceptionB)(288)
  124. self.Mixed_6b = get_cls("C1","C",InceptionC)(768, ch_7x7=128)
  125. self.Mixed_6c = get_cls("C2","C",InceptionC)(768, ch_7x7=160)
  126. self.Mixed_6d = get_cls("C3","C",InceptionC)(768, ch_7x7=160)
  127. self.Mixed_6e = get_cls("C4","C",InceptionC)(768, ch_7x7=192)
  128. self.Mixed_7a = get_cls("D1","D",InceptionD)(768)
  129. self.Mixed_7b = get_cls("E1","E",InceptionE)(1280)
  130. self.Mixed_7c = get_cls("E2","E",InceptionE)(2048)
  131. self.avgpool = lambda x: Tensor.avg_pool2d(x, kernel_size=(8,8), padding=1)
  132. self.fc = Linear(2048, num_classes)
  133. def __call__(self, x:Tensor) -> Tensor:
  134. return x.sequential([
  135. self.Conv2d_1a_3x3,
  136. self.Conv2d_2a_3x3,
  137. self.Conv2d_2b_3x3,
  138. self.maxpool1,
  139. self.Conv2d_3b_1x1,
  140. self.Conv2d_4a_3x3,
  141. self.maxpool2,
  142. self.Mixed_5b,
  143. self.Mixed_5c,
  144. self.Mixed_5d,
  145. self.Mixed_6a,
  146. self.Mixed_6b,
  147. self.Mixed_6c,
  148. self.Mixed_6d,
  149. self.Mixed_6e,
  150. self.Mixed_7a,
  151. self.Mixed_7b,
  152. self.Mixed_7c,
  153. self.avgpool,
  154. lambda y: y.reshape(x.shape[0],-1),
  155. self.fc,
  156. ])
  157. # FID Inception Variation
  158. class FidInceptionA(InceptionA):
  159. def __call__(self, x:Tensor) -> Tensor:
  160. outputs = [
  161. self.branch1x1(x),
  162. x.sequential([self.branch5x5_1, self.branch5x5_2]),
  163. x.sequential([self.branch3x3dbl_1, self.branch3x3dbl_2, self.branch3x3dbl_3]),
  164. self.branch_pool(x.avg_pool2d(kernel_size=(3,3), stride=1, padding=1, count_include_pad=False))
  165. ]
  166. return Tensor.cat(*outputs, dim=1)
  167. class FidInceptionC(InceptionC):
  168. def __call__(self, x:Tensor) -> Tensor:
  169. outputs = [
  170. self.branch1x1(x),
  171. x.sequential([self.branch7x7_1, self.branch7x7_2, self.branch7x7_3]),
  172. x.sequential([self.branch7x7dbl_1, self.branch7x7dbl_2, self.branch7x7dbl_3, self.branch7x7dbl_4, self.branch7x7dbl_5]),
  173. self.branch_pool(x.avg_pool2d(kernel_size=(3,3), stride=1, padding=1, count_include_pad=False))
  174. ]
  175. return Tensor.cat(*outputs, dim=1)
  176. class FidInceptionE1(InceptionE):
  177. def __call__(self, x:Tensor) -> Tensor:
  178. branch3x3 = self.branch3x3_1(x)
  179. branch3x3dbl = x.sequential([self.branch3x3dbl_1, self.branch3x3dbl_2])
  180. outputs = [
  181. self.branch1x1(x),
  182. Tensor.cat(self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3), dim=1),
  183. Tensor.cat(self.branch3x3dbl_3a(branch3x3dbl), self.branch3x3dbl_3b(branch3x3dbl), dim=1),
  184. self.branch_pool(x.avg_pool2d(kernel_size=(3,3), stride=1, padding=1, count_include_pad=False)),
  185. ]
  186. return Tensor.cat(*outputs, dim=1)
  187. class FidInceptionE2(InceptionE):
  188. def __call__(self, x:Tensor) -> Tensor:
  189. branch3x3 = self.branch3x3_1(x)
  190. branch3x3dbl = x.sequential([self.branch3x3dbl_1, self.branch3x3dbl_2])
  191. outputs = [
  192. self.branch1x1(x),
  193. Tensor.cat(self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3), dim=1),
  194. Tensor.cat(self.branch3x3dbl_3a(branch3x3dbl), self.branch3x3dbl_3b(branch3x3dbl), dim=1),
  195. self.branch_pool(x.max_pool2d(kernel_size=(3,3), stride=1, padding=1)),
  196. ]
  197. return Tensor.cat(*outputs, dim=1)
  198. class FidInceptionV3:
  199. def __init__(self):
  200. inception = Inception3(cls_map={
  201. "A": FidInceptionA,
  202. "C": FidInceptionC,
  203. "E1": FidInceptionE1,
  204. "E2": FidInceptionE2,
  205. })
  206. self.Conv2d_1a_3x3 = inception.Conv2d_1a_3x3
  207. self.Conv2d_2a_3x3 = inception.Conv2d_2a_3x3
  208. self.Conv2d_2b_3x3 = inception.Conv2d_2b_3x3
  209. self.Conv2d_3b_1x1 = inception.Conv2d_3b_1x1
  210. self.Conv2d_4a_3x3 = inception.Conv2d_4a_3x3
  211. self.Mixed_5b = inception.Mixed_5b
  212. self.Mixed_5c = inception.Mixed_5c
  213. self.Mixed_5d = inception.Mixed_5d
  214. self.Mixed_6a = inception.Mixed_6a
  215. self.Mixed_6b = inception.Mixed_6b
  216. self.Mixed_6c = inception.Mixed_6c
  217. self.Mixed_6d = inception.Mixed_6d
  218. self.Mixed_6e = inception.Mixed_6e
  219. self.Mixed_7a = inception.Mixed_7a
  220. self.Mixed_7b = inception.Mixed_7b
  221. self.Mixed_7c = inception.Mixed_7c
  222. def load_from_pretrained(self):
  223. state_dict = torch_load(str(fetch("https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth", "pt_inception-2015-12-05-6726825d.pth")))
  224. for k,v in state_dict.items():
  225. if k.endswith(".num_batches_tracked"):
  226. state_dict[k] = v.reshape(1)
  227. load_state_dict(self, state_dict)
  228. return self
  229. def __call__(self, x:Tensor) -> Tensor:
  230. x = x.interpolate((299,299), mode="linear")
  231. x = (x * 2) - 1
  232. x = x.sequential([
  233. self.Conv2d_1a_3x3,
  234. self.Conv2d_2a_3x3,
  235. self.Conv2d_2b_3x3,
  236. lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=2, dilation=1),
  237. self.Conv2d_3b_1x1,
  238. self.Conv2d_4a_3x3,
  239. lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=2, dilation=1),
  240. self.Mixed_5b,
  241. self.Mixed_5c,
  242. self.Mixed_5d,
  243. self.Mixed_6a,
  244. self.Mixed_6b,
  245. self.Mixed_6c,
  246. self.Mixed_6d,
  247. self.Mixed_6e,
  248. self.Mixed_7a,
  249. self.Mixed_7b,
  250. self.Mixed_7c,
  251. lambda x: Tensor.avg_pool2d(x, kernel_size=(8,8)),
  252. ])
  253. return x