utils.py

import cv2
import math
import numpy as np
import os
import queue
import threading
import torch
from basicsr.utils.download_util import load_file_from_url
from torch.nn import functional as F

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


class RealESRGANer():
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model_path (str): The path to the pretrained model. It can be a URL (the model will then be downloaded
            automatically first). It can also be a list of paths, for deep network interpolation.
        dni_weight (list[float]): The interpolation weights for deep network interpolation (see ``dni``).
            Default: None.
        model (nn.Module): The defined network. Default: None.
        tile (int): Since overly large images can run out of GPU memory, this tile option first crops the
            input image into tiles, processes each of them, and finally merges them back into one image.
            0 means tiling is not used. Default: 0.
        tile_pad (int): The padding size for each tile, used to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (bool): Whether to use half precision during inference. Default: False.
        device (torch.device): The device to run inference on. Default: None (chosen automatically).
        gpu_id (int): The id of the GPU to run on. Default: None.
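
    Example (an illustrative sketch; ``RRDBNet`` ships with basicsr, and the
    weight path is hypothetical)::

        from basicsr.archs.rrdbnet_arch import RRDBNet

        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        upsampler = RealESRGANer(scale=4, model_path='weights/RealESRGAN_x4plus.pth', model=model, tile=400)
        img = cv2.imread('input.png', cv2.IMREAD_UNCHANGED)
        output, img_mode = upsampler.enhance(img, outscale=4)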
  23. """

    def __init__(self,
                 scale,
                 model_path,
                 dni_weight=None,
                 model=None,
                 tile=0,
                 tile_pad=10,
                 pre_pad=10,
                 half=False,
                 device=None,
                 gpu_id=None):
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half

        # initialize model
        if gpu_id:
            self.device = torch.device(
                f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
        else:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device

        if isinstance(model_path, list):
            # dni
            assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the same length.'
            loadnet = self.dni(model_path[0], model_path[1], dni_weight)
        else:
            # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
            if model_path.startswith('https://'):
                model_path = load_file_from_url(
                    url=model_path,
                    model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'),
                    progress=True,
                    file_name=None)
            loadnet = torch.load(model_path, map_location=torch.device('cpu'))

        # prefer to use params_ema
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        model.load_state_dict(loadnet[keyname], strict=True)

        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'):
        """Deep network interpolation.

        Linearly blends the parameters of two networks:
        ``dni_weight[0] * theta_a + dni_weight[1] * theta_b`` for every key.

        ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
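
        Example (illustrative; the weight paths are hypothetical)::

            # a 50/50 blend of two checkpoints, e.g. a model and its
            # denoising variant
            loadnet = self.dni('weights/net_a.pth', 'weights/net_b.pth', [0.5, 0.5])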
  73. """
  74. net_a = torch.load(net_a, map_location=torch.device(loc))
  75. net_b = torch.load(net_b, map_location=torch.device(loc))
  76. for k, v_a in net_a[key].items():
  77. net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
  78. return net_a

    def pre_process(self, img):
        """Pre-process (pre-pad and mod-pad) the image so that its size is divisible as the network requires."""
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad for divisible borders
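        # (Presumably the x2 and x1 architectures pixel-unshuffle the input by
        # 2 and 4 respectively before the network body, so height and width
        # must be multiples of that factor; the x4 models need no mod pad.)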
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if (h % self.mod_scale != 0):
                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
            if (w % self.mod_scale != 0):
                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')

    def process(self):
        # model inference
        self.output = self.model(self.img)

    def tile_process(self):
        """First crop the input image into tiles, then process each tile.
        Finally, all the processed tiles are merged into one image.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)
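        # e.g. with tile_size=400, a 1000x500 input gives tiles_x=3 and
        # tiles_y=2, i.e. 6 tiles in total (edge tiles are smaller than
        # tile_size)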

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print('Error', error)
                    print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
                    # skip this tile (it stays black in the output) instead of
                    # falling through to an undefined output_tile
                    continue

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                       output_start_x_tile:output_end_x_tile]

    def post_process(self):
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
        # remove prepad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
        return self.output

    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
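        """Upsample an image (and its alpha channel, if any).

        Args:
            img (np.ndarray): Input image as read by ``cv2.imread`` (grayscale, BGR, or BGRA;
                uint8 or 16-bit).
            outscale (float): Final rescale factor; if it differs from ``self.scale``, the result
                is resized with cv2. Default: None.
            alpha_upsampler (str): 'realesrgan' upsamples the alpha channel with the network;
                otherwise a bilinear cv2 resize is used. Default: 'realesrgan'.

        Returns:
            tuple[np.ndarray, str]: The upsampled image and the detected image mode
            ('L', 'RGB', or 'RGBA').
        """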
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == 'realesrgan':
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        output_img = self.post_process()
        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == 'RGBA':
            if alpha_upsampler == 'realesrgan':
                self.pre_process(alpha)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_alpha = self.post_process()
                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use the cv2 resize for alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output, (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ), interpolation=cv2.INTER_LANCZOS4)

        return output, img_mode


class PrefetchReader(threading.Thread):
    """Prefetch images.

    Args:
        img_list (list[str]): A list of image paths to be read.
        num_prefetch_queue (int): Capacity of the prefetch queue.
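
    Example (an illustrative sketch; ``paths`` is a hypothetical list of
    image files)::

        reader = PrefetchReader(paths, num_prefetch_queue=4)
        reader.start()
        for img in reader:
            ...  # consume images while the thread reads ahead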
  240. """

    def __init__(self, img_list, num_prefetch_queue):
        super().__init__()
        self.que = queue.Queue(num_prefetch_queue)
        self.img_list = img_list

    def run(self):
        for img_path in self.img_list:
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            self.que.put(img)
        self.que.put(None)

    def __next__(self):
        next_item = self.que.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self


class IOConsumer(threading.Thread):
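    """A worker thread that saves images pulled from a queue.

    Each message is expected to be a dict with 'output' (the image) and
    'save_path' keys; the string 'quit' stops the worker. ``opt`` is stored
    but not used here.

    Example (an illustrative sketch)::

        que = queue.Queue()
        worker = IOConsumer(opt, que, qid=0)
        worker.start()
        que.put({'output': output, 'save_path': 'results/out.png'})
        que.put('quit')
        worker.join()
    """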

    def __init__(self, opt, que, qid):
        super().__init__()
        self._queue = que
        self.qid = qid
        self.opt = opt

    def run(self):
        while True:
            msg = self._queue.get()
            if isinstance(msg, str) and msg == 'quit':
                break

            output = msg['output']
            save_path = msg['save_path']
            cv2.imwrite(save_path, output)
        print(f'IO worker {self.qid} is done.')