vgg7.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. import sys
  2. import random
  3. import json
  4. import numpy
  5. from pathlib import Path
  6. from PIL import Image
  7. from tinygrad.tensor import Tensor
  8. from tinygrad.nn.optim import SGD
  9. from tinygrad.nn.state import safe_save, safe_load, get_state_dict, load_state_dict
  10. from examples.vgg7_helpers.waifu2x import image_load, image_save, Vgg7
  11. # amount of context erased by model
  12. CONTEXT = 7
  13. def get_sample_count(samples_dir):
  14. try:
  15. samples_dir_count_file = open(samples_dir + "/sample_count.txt", "r")
  16. v = samples_dir_count_file.readline()
  17. samples_dir_count_file.close()
  18. return int(v)
  19. except:
  20. return 0
  21. def set_sample_count(samples_dir, sc):
  22. with open(samples_dir + "/sample_count.txt", "w") as file:
  23. file.write(str(sc) + "\n")
# Entry guard: with no subcommand argument, print the usage text and exit non-zero.
if len(sys.argv) < 2:
  print("python3 -m examples.vgg7 import MODELJSON MODEL")
  print(" imports a waifu2x JSON vgg_7 model, i.e. waifu2x/models/vgg_7/art/scale2.0x_model.json")
  print(" into a safetensors file")
  print(" weight tensors are ordered in tinygrad/ncnn form, as so: (outC,inC,H,W)")
  print(" *this format is used by most other commands in this program*")
  print("python3 -m examples.vgg7 import_kinne MODEL_KINNE MODEL_SAFETENSORS")
  print(" imports a model in 'KINNE' format (raw floats: used by older versions of this example) into safetensors")
  print("python3 -m examples.vgg7 execute MODEL IMG_IN IMG_OUT")
  print(" given an already-nearest-neighbour-scaled image, runs vgg7 on it")
  print(" output image has 7 pixels removed on all edges")
  print(" do not run on large images, will have *hilarious* RAM use")
  print("python3 -m examples.vgg7 execute_full MODEL IMG_IN IMG_OUT")
  print(" does the 'whole thing' (padding, tiling)")
  print(" safe for large images, etc.")
  print("python3 -m examples.vgg7 new MODEL")
  print(" creates a new model (experimental)")
  print("python3 -m examples.vgg7 train MODEL SAMPLES_DIR ROUNDS ROUNDS_SAVE")
  print(" trains a model (experimental)")
  print(" (how experimental? well, every time I tried it, it flooded w/ NaNs)")
  print(" note: ROUNDS < 0 means 'forever'. ROUNDS_SAVE <= 0 is not a good idea.")
  print(" expects roughly execute's input as SAMPLES_DIR/IDXa.png")
  print(" expects roughly execute's output as SAMPLES_DIR/IDXb.png")
  print(" (i.e. my_samples/0a.png is the first pre-nearest-scaled image,")
  print(" my_samples/0b.png is the first original image)")
  print(" in addition, SAMPLES_DIR/samples_count.txt indicates sample count")
  print(" won't pad or tile, so keep image sizes sane")
  print("python3 -m examples.vgg7 samplify IMG_A IMG_B SAMPLES_DIR SIZE")
  print(" creates overlapping micropatches (SIZExSIZE w/ 7-pixel border) for training")
  print(" maintains/creates samples_count.txt automatically")
  print(" unlike training, IMG_A must be exactly half the size of IMG_B")
  sys.exit(1)

# Subcommand selector ("import", "execute", "train", ...) drives the dispatch below.
cmd = sys.argv[1]
# Single model instance shared by every subcommand and by load_and_save().
vgg7 = Vgg7()
  58. def nansbane(p):
  59. if numpy.isnan(numpy.min(p.numpy())):
  60. raise Exception("A NaN in the model has been detected. This model will not be interacted with to prevent further damage.")
  61. def load_and_save(path, save):
  62. if save:
  63. for v in vgg7.get_parameters():
  64. nansbane(v)
  65. st = get_state_dict(vgg7)
  66. safe_save(st, path)
  67. else:
  68. st = safe_load(path)
  69. load_state_dict(vgg7, st)
  70. for v in vgg7.get_parameters():
  71. nansbane(v)
if cmd == "import":
  # Convert a waifu2x JSON vgg_7 model into a safetensors file.
  src = sys.argv[2]
  model = sys.argv[3]
  vgg7.load_waifu2x_json(json.load(open(src, "rb")))
  load_and_save(model, True)
elif cmd == "import_kinne":
  # tinygrad wasn't doing safetensors when this example was written
  # it's possible someone might have a model around using the resulting interim format
  src = sys.argv[2]
  model = sys.argv[3]
  index = 0
  for t in vgg7.get_parameters():
    # KINNE layout: one raw little-endian float32 file per parameter, numbered in parameter order.
    fn = src + "/snoop_bin_" + str(index) + ".bin"
    t.assign(Tensor(numpy.fromfile(fn, "<f4")).reshape(shape=t.shape))
    index += 1
  load_and_save(model, True)
elif cmd == "execute":
  # Run the network directly on a whole image (no tiling — RAM-hungry on big inputs).
  model = sys.argv[2]
  in_file = sys.argv[3]
  out_file = sys.argv[4]
  load_and_save(model, False)
  image_save(out_file, vgg7.forward(Tensor(image_load(in_file))).numpy())
elif cmd == "execute_full":
  # Padded + tiled execution; safe for large images.
  model = sys.argv[2]
  in_file = sys.argv[3]
  out_file = sys.argv[4]
  load_and_save(model, False)
  # 156 is the tile size handed to forward_tiled — presumably chosen for memory use; see waifu2x helper.
  image_save(out_file, vgg7.forward_tiled(image_load(in_file), 156))
elif cmd == "new":
  # Save a freshly-constructed (untrained) model.
  model = sys.argv[2]
  load_and_save(model, True)
elif cmd == "train":
  # Experimental training loop: MODEL SAMPLES_DIR ROUNDS ROUNDS_SAVE.
  model = sys.argv[2]
  samples_base = sys.argv[3]
  samples_count = get_sample_count(samples_base)
  rounds = int(sys.argv[4])
  rounds_per_save = int(sys.argv[5])
  load_and_save(model, False)
  # Initialize sample probabilities.
  # This is used to try and get the network to focus on "interesting" samples,
  # which works nicely with the microsample system.
  sample_probs = None
  sample_probs_path = model + "_sample_probs.bin"
  try:
    # try to read...
    sample_probs = numpy.fromfile(sample_probs_path, "<f8")
    if sample_probs.shape[0] != samples_count:
      print("sample probs size != sample count - initializing")
      sample_probs = None
  except:
    # it's fine
    # NOTE(review): bare except is deliberate best-effort here — a missing/corrupt
    # probability file just triggers re-initialization below.
    print("sample probs could not be loaded - initializing")
  if sample_probs is None:
    # This stupidly high amount is used to force an initial pass over all samples
    sample_probs = numpy.ones(samples_count) * 1000
  print("Training...")
  # Adam has a tendency to destroy the state of the network when restarted
  # Plus it's slower
  optim = SGD(vgg7.get_parameters())
  rnum = 0
  while True:
    # The way the -1 option works is that rnum is never -1.
    if rnum == rounds:
      break
    # Pick a sample, weighted by each sample's last observed loss.
    sample_idx = 0
    try:
      sample_idx = numpy.random.choice(samples_count, p = sample_probs / sample_probs.sum())
    except:
      # NOTE(review): numpy.random.choice raises ValueError when probabilities
      # don't sum to exactly 1 (float drift); fall back to a uniform pick.
      print("exception occurred (PROBABLY value-probabilities-dont-sum-to-1)")
      sample_idx = random.randint(0, samples_count - 1)
    x_img = image_load(samples_base + "/" + str(sample_idx) + "a.png")
    y_img = image_load(samples_base + "/" + str(sample_idx) + "b.png")
    sample_x = Tensor(x_img, requires_grad = False)
    sample_y = Tensor(y_img, requires_grad = False)
    # magic code roughly from readme example
    # An explanation, in case anyone else has to go down this path:
    # This runs the actual network normally
    out = vgg7.forward(sample_x)
    # Subtraction determines error here (as this is an image, not classification).
    # *Abs is the important bit* - at least for me, anyway.
    # The training process seeks to minimize this 'loss' value.
    # Minimization of loss *tends towards negative infinity*, so without the abs,
    # or without an implicit abs (the mul in the README),
    # loss will always go haywire in one direction or another.
    # Mean determines how errors are treated.
    # Do not use Sum. I tried that. It worked while I was using 1x1 patches...
    # Then it went exponential.
    # Also, Mean goes *after* abs. I realize this should have been obvious to me.
    loss = sample_y.sub(out).abs().mean()
    # This is the bit where tinygrad works backward from the loss
    optim.zero_grad()
    loss.backward()
    # And this updates the parameters
    optim.step()
    # warning: used by sample probability adjuster
    loss_indicator = loss.max().numpy()
    print("Round " + str(rnum) + " : " + str(loss_indicator))
    # Periodic checkpoint of both the model and the probability table.
    if (rnum % rounds_per_save) == 0:
      print("Saving")
      load_and_save(model, True)
      sample_probs.astype("<f8", "C").tofile(sample_probs_path)
    # Update round state
    # Number
    rnum = rnum + 1
    # Probability management
    # there must always be a probability, no matter how slim, even if loss goes to 0
    sample_probs[sample_idx] = max(loss_indicator, 1.e-10)
  # if we were told to save every round, we already saved
  if rounds_per_save != 1:
    print("Done with all rounds, saving")
    load_and_save(model, True)
    sample_probs.astype("<f8", "C").tofile(sample_probs_path)
elif cmd == "samplify":
  # Slice a pre-scaled/original image pair into overlapping training micropatches.
  a_img = sys.argv[2]
  b_img = sys.argv[3]
  samples_base = sys.argv[4]
  sample_size = int(sys.argv[5])
  samples_count = get_sample_count(samples_base)
  # This bit is interesting because it actually does some work.
  # Not much, but some work.
  a_img = image_load(a_img)
  b_img = image_load(b_img)
  # as with the main library body,
  # Y X order is used here
  # assertion before pre-upscaling is performed
  assert a_img.shape[2] == (b_img.shape[2] // 2)
  assert a_img.shape[3] == (b_img.shape[3] // 2)
  # pre-upscaling - this matches the sizes (and coordinates)
  # (repeat along axes 2 and 3 = nearest-neighbour 2x upscale of the small image)
  a_img = a_img.repeat(2, 2).repeat(2, 3)
  samples_added = 0
  # actual patch extraction
  for posy in range(CONTEXT, b_img.shape[2] - (CONTEXT + sample_size - 1), sample_size):
    for posx in range(CONTEXT, b_img.shape[3] - (CONTEXT + sample_size - 1), sample_size):
      # this is a viable patch location, add it
      # note the ranges here:
      # + there are always CONTEXT pixels *before* the point
      # + with no subtraction at the end, there'd already be a pixel *at* the point,
      #   as ranges are exclusive
      # + additionally, there are sample_size - 1 additional sample pixels
      # + additionally, there are CONTEXT additional pixels
      # + therefore there are CONTEXT + sample_size pixels *at & after* the point
      patch_x = a_img[:, :, posy - CONTEXT : posy + CONTEXT + sample_size, posx - CONTEXT : posx + CONTEXT + sample_size]
      patch_y = b_img[:, :, posy : posy + sample_size, posx : posx + sample_size]
      image_save(f"{samples_base}/{str(samples_count)}a.png", patch_x)
      image_save(f"{samples_base}/{str(samples_count)}b.png", patch_y)
      samples_count += 1
      samples_added += 1
  print(f"Added {str(samples_added)} samples")
  # Persist the running total so later samplify runs append rather than overwrite.
  set_sample_count(samples_base, samples_count)
else:
  print("unknown command")