# waifu2x.py
# Implementation of waifu2x vgg7 in tinygrad.
# Obviously, not developed, supported, etc. by the original waifu2x author(s).
import numpy
from PIL import Image

from tinygrad.helpers import fetch
from tinygrad.tensor import Tensor
# File Formats
# tinygrad convolution tensor input layout is (1,c,y,x) - and therefore the form for all images used in the project
# tinygrad convolution tensor weight layout is (outC,inC,H,W) - this matches NCNN (and therefore KINNE), but not waifu2x json
  10. def image_load(path) -> numpy.ndarray:
  11. """
  12. Loads an image in the shape expected by other functions in this module.
  13. Doesn't Tensor it, in case you need to do further work with it.
  14. """
  15. # file
  16. na = numpy.array(Image.open(path))
  17. if na.shape[2] == 4:
  18. # RGBA -> RGB (covers opaque images with alpha channels)
  19. na = na[:,:,0:3]
  20. # fix shape
  21. na = numpy.moveaxis(na, [2,0,1], [0,1,2])
  22. # shape is now (3,h,w), add 1
  23. na = na.reshape(1,3,na.shape[1],na.shape[2])
  24. # change type
  25. na = na.astype("float32") / 255.0
  26. return na
  27. def image_save(path, na: numpy.ndarray):
  28. """
  29. Saves an image of the shape expected by other functions in this module.
  30. However, note this expects a numpy array.
  31. """
  32. # change type
  33. na = numpy.fmax(numpy.fmin(na * 255.0, 255), 0).astype("uint8")
  34. # shape is now (1,3,h,w), remove 1
  35. na = na.reshape(3,na.shape[2],na.shape[3])
  36. # fix shape
  37. na = numpy.moveaxis(na, [0,1,2], [2,0,1])
  38. # shape is now (h,w,3)
  39. # file
  40. Image.fromarray(na).save(path)
# The Model
  42. class Conv3x3Biased:
  43. """
  44. A 3x3 convolution layer with some utility functions.
  45. """
  46. def __init__(self, inC, outC, last = False):
  47. # The properties must be named as "W" and "b".
  48. # This is in an attempt to try and be roughly compatible with https://github.com/FHPythonUtils/Waifu2x
  49. # though this cannot necessarily account for transposition and other such things.
  50. # Massively overstate the weights to get them to be focused on,
  51. # since otherwise the biases overrule everything
  52. self.W = Tensor.uniform(outC, inC, 3, 3) * 16.0
  53. # Layout-wise, blatant cheat, but serious_mnist does it. I'd guess channels either have to have a size of 1 or whatever the target is?
  54. # Values-wise, entirely different blatant cheat.
  55. # In most cases, use uniform bias, but tiny.
  56. # For the last layer, use just 0.5, constant.
  57. if last:
  58. self.b = Tensor.zeros(1, outC, 1, 1) + 0.5
  59. else:
  60. self.b = Tensor.uniform(1, outC, 1, 1)
  61. def forward(self, x):
  62. # You might be thinking, "but what about padding?"
  63. # Answer: Tiling is used to stitch everything back together, though you could pad the image before providing it.
  64. return x.conv2d(self.W).add(self.b)
  65. def get_parameters(self) -> list:
  66. return [self.W, self.b]
  67. def load_waifu2x_json(self, layer: dict):
  68. # Weights in this file are outChannel,inChannel,X,Y.
  69. # Not outChannel,inChannel,Y,X.
  70. # Therefore, transpose it before assignment.
  71. # I have long since forgotten how I worked this out.
  72. self.W.assign(Tensor(layer["weight"]).reshape(shape=self.W.shape).transpose(2, 3))
  73. self.b.assign(Tensor(layer["bias"]).reshape(shape=self.b.shape))
  74. class Vgg7:
  75. """
  76. The 'vgg7' waifu2x network.
  77. Lower quality and slower than even upconv7 (nevermind cunet), but is very easy to implement and test.
  78. """
  79. def __init__(self):
  80. self.conv1 = Conv3x3Biased(3, 32)
  81. self.conv2 = Conv3x3Biased(32, 32)
  82. self.conv3 = Conv3x3Biased(32, 64)
  83. self.conv4 = Conv3x3Biased(64, 64)
  84. self.conv5 = Conv3x3Biased(64, 128)
  85. self.conv6 = Conv3x3Biased(128, 128)
  86. self.conv7 = Conv3x3Biased(128, 3, True)
  87. def forward(self, x):
  88. """
  89. Forward pass: Actually runs the network.
  90. Input format: (1, 3, Y, X)
  91. Output format: (1, 3, Y - 14, X - 14)
  92. (the - 14 represents the 7-pixel context border that is lost)
  93. """
  94. x = self.conv1.forward(x).leakyrelu(0.1)
  95. x = self.conv2.forward(x).leakyrelu(0.1)
  96. x = self.conv3.forward(x).leakyrelu(0.1)
  97. x = self.conv4.forward(x).leakyrelu(0.1)
  98. x = self.conv5.forward(x).leakyrelu(0.1)
  99. x = self.conv6.forward(x).leakyrelu(0.1)
  100. x = self.conv7.forward(x)
  101. return x
  102. def get_parameters(self) -> list:
  103. return self.conv1.get_parameters() + self.conv2.get_parameters() + self.conv3.get_parameters() + self.conv4.get_parameters() + self.conv5.get_parameters() + self.conv6.get_parameters() + self.conv7.get_parameters()
  104. def load_from_pretrained(self, intent = "art", subtype = "scale2.0x"):
  105. """
  106. Downloads a nagadomi/waifu2x JSON weight file and loads it.
  107. """
  108. import json
  109. data = json.loads(fetch("https://github.com/nagadomi/waifu2x/raw/master/models/vgg_7/" + intent + "/" + subtype + "_model.json").read_bytes())
  110. self.load_waifu2x_json(data)
  111. def load_waifu2x_json(self, data: list):
  112. """
  113. Loads weights from one of the waifu2x JSON files, i.e. waifu2x/models/vgg_7/art/noise0_model.json
  114. data (passed in) is assumed to be the output of json.load or some similar on such a file
  115. """
  116. self.conv1.load_waifu2x_json(data[0])
  117. self.conv2.load_waifu2x_json(data[1])
  118. self.conv3.load_waifu2x_json(data[2])
  119. self.conv4.load_waifu2x_json(data[3])
  120. self.conv5.load_waifu2x_json(data[4])
  121. self.conv6.load_waifu2x_json(data[5])
  122. self.conv7.load_waifu2x_json(data[6])
  123. def forward_tiled(self, image: numpy.ndarray, tile_size: int) -> numpy.ndarray:
  124. """
  125. Given an ndarray image as loaded by image_load (NOT a tensor), scales it, pads it, splits it up, forwards the pieces, and reconstitutes it.
  126. Note that you really shouldn't try to run anything not (1, 3, *, *) through this.
  127. """
  128. # Constant that only really gets repeated a ton here.
  129. context = 7
  130. context2 = context + context
  131. # Notably, numpy is used here because it makes this fine manipulation a lot simpler.
  132. # Scaling first - repeat on axis 2 and axis 3 (Y & X)
  133. image = image.repeat(2, 2).repeat(2, 3)
  134. # Resulting image buffer. This is made before the input is padded,
  135. # since the input has the padded shape right now.
  136. image_out = numpy.zeros(image.shape)
  137. # Padding next. Note that this padding is done on the whole image.
  138. # Padding the tiles would lose critical context, cause seams, etc.
  139. image = numpy.pad(image, [[0, 0], [0, 0], [context, context], [context, context]], mode = "edge")
  140. # Now for tiling.
  141. # The output tile size is the usable output from an input tile (tile_size).
  142. # As such, the tiles overlap.
  143. out_tile_size = tile_size - context2
  144. for out_y in range(0, image_out.shape[2], out_tile_size):
  145. for out_x in range(0, image_out.shape[3], out_tile_size):
  146. # Input is sourced from the same coordinates, but some stuff ought to be
  147. # noted here for future reference:
  148. # + out_x/y's equivalent position w/ the padding is out_x + context.
  149. # + The output, however, is without context. Input needs context.
  150. # + Therefore, the input rectangle is expanded on all sides by context.
  151. # + Therefore, the input position has the context subtracted again.
  152. # + Therefore:
  153. in_y = out_y
  154. in_x = out_x
  155. # not shown: in_w/in_h = tile_size (as opposed to out_tile_size)
  156. # Extract tile.
  157. # Note that numpy will auto-crop this at the bottom-right.
  158. # This will never be a problem, as tiles are specifically chosen within the padded section.
  159. tile = image[:, :, in_y:in_y + tile_size, in_x:in_x + tile_size]
  160. # Extracted tile dimensions -> output dimensions
  161. # This is important because of said cropping, otherwise it'd be interior tile size.
  162. out_h = tile.shape[2] - context2
  163. out_w = tile.shape[3] - context2
  164. # Process tile.
  165. tile_t = Tensor(tile)
  166. tile_fwd_t = self.forward(tile_t)
  167. # Replace tile.
  168. image_out[:, :, out_y:out_y + out_h, out_x:out_x + out_w] = tile_fwd_t.numpy()
  169. return image_out