utils.py

import cv2
import math
import numpy as np
import os
import queue
import threading
import torch
from basicsr.utils.download_util import load_file_from_url
from torch.nn import functional as F

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


class RealESRGANer():
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model_path (str): The path to the pretrained model. It can also be a URL
            (the model will first be downloaded automatically).
        dni_weight (list[float]): Interpolation weights for Deep Network Interpolation.
            Only used when model_path is a list of two paths. Default: None.
        model (nn.Module): The defined network. Default: None.
        tile (int): Tile size for tiled inference. Since very large images can run out of
            GPU memory, this option first crops the input image into tiles, processes each
            tile separately, and finally merges them back into one image. 0 disables
            tiling. Default: 0.
        tile_pad (int): The pad size for each tile, used to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (bool): Whether to use half precision during inference. Default: False.
        device (torch.device): Device to run inference on. Default: None (auto-detect).
        gpu_id (int): Specific GPU id to use when several GPUs are available. Default: None.
    """

    def __init__(self,
                 scale,
                 model_path,
                 dni_weight=None,
                 model=None,
                 tile=0,
                 tile_pad=10,
                 pre_pad=10,
                 half=False,
                 device=None,
                 gpu_id=None):
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half

        # initialize model
        if gpu_id:
            self.device = torch.device(
                f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
        else:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device

        if isinstance(model_path, list):
            # dni
            assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the same length.'
            loadnet = self.dni(model_path[0], model_path[1], dni_weight)
        else:
            # if the model_path starts with https, it will first download models to the folder: weights
            if model_path.startswith('https://'):
                model_path = load_file_from_url(
                    url=model_path, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
            loadnet = torch.load(model_path, map_location=torch.device('cpu'))

        # prefer to use params_ema
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        model.load_state_dict(loadnet[keyname], strict=True)
        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'):
        """Deep network interpolation.

        ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
        """
        net_a = torch.load(net_a, map_location=torch.device(loc))
        net_b = torch.load(net_b, map_location=torch.device(loc))
        for k, v_a in net_a[key].items():
            net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
        return net_a
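
    # For example, dni_weight=[0.5, 0.5] averages the two networks parameter-wise:
    # 0.5 * net_a + 0.5 * net_b. Upstream Real-ESRGAN uses this to blend the
    # realesr-general-x4v3 model with its denoising (wdn) variant, giving a
    # continuously adjustable denoise strength.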

    def pre_process(self, img):
        """Pre-process the input, e.g. pre-pad and mod-pad, so that its height and width are divisible."""
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad for divisible borders
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if (h % self.mod_scale != 0):
                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
            if (w % self.mod_scale != 0):
                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
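
    # Worked example (pre_pad aside): with scale=2, mod_scale is 2, so a 100x101
    # (h x w) input gets mod_pad_w = 2 - 101 % 2 = 1 and is reflect-padded to
    # 100x102 before inference; post_process later crops the extra
    # mod_pad_w * scale = 2 columns off the upscaled output.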

    def process(self):
        # model inference
        self.output = self.model(self.img)

    def tile_process(self):
        """It will first crop input images to tiles, and then process each tile.
        Finally, all the processed tiles are merged into one image.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print('Error', error)
                    # skip this tile on failure; otherwise output_tile below
                    # would be undefined (or stale from a previous iteration)
                    continue
                print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                       output_start_x_tile:output_end_x_tile]
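
    # Each tile above is inferred with up to tile_pad pixels of surrounding
    # context, but only the unpadded center region is copied into self.output,
    # which is what suppresses visible seams at tile borders.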

    def post_process(self):
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
        # remove prepad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
        return self.output

    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == 'realesrgan':
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        output_img = self.post_process()
        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == 'RGBA':
            if alpha_upsampler == 'realesrgan':
                self.pre_process(alpha)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_alpha = self.post_process()
                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use the cv2 resize for alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output, (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ), interpolation=cv2.INTER_LANCZOS4)

        return output, img_mode
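
    # Note: enhance() expects images in OpenCV's channel layout (BGR, BGRA, or
    # single-channel grayscale), e.g. as returned by cv2.imread, and returns the
    # result in the same layout, so it can be written out directly with cv2.imwrite.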


class PrefetchReader(threading.Thread):
    """Prefetch images in a background thread.

    Args:
        img_list (list[str]): A list of image paths to be read.
        num_prefetch_queue (int): Size of the prefetch queue.
    """

    def __init__(self, img_list, num_prefetch_queue):
        super().__init__()
        self.que = queue.Queue(num_prefetch_queue)
        self.img_list = img_list

    def run(self):
        for img_path in self.img_list:
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            self.que.put(img)

        self.que.put(None)

    def __next__(self):
        next_item = self.que.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self
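
# A sketch of the intended use (illustrative): start the thread, then iterate.
# Iteration stops at the None sentinel queued after the last image; note that
# cv2.imread also returns None on a failed read, which would end iteration early.
#
#     reader = PrefetchReader(paths, num_prefetch_queue=4)
#     reader.start()
#     for img in reader:
#         process(img)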


class IOConsumer(threading.Thread):

    def __init__(self, opt, que, qid):
        super().__init__()
        self._queue = que
        self.qid = qid
        self.opt = opt

    def run(self):
        while True:
            msg = self._queue.get()
            if isinstance(msg, str) and msg == 'quit':
                break

            output = msg['output']
            save_path = msg['save_path']
            cv2.imwrite(save_path, output)
        print(f'IO worker {self.qid} is done.')
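
# A sketch of the intended use (illustrative; num_workers and result_img are
# placeholder names): workers consume dicts carrying 'output' and 'save_path',
# and each worker is shut down by queueing one 'quit' message per worker.
#
#     que = queue.Queue()
#     workers = [IOConsumer(opt, que, qid=i) for i in range(num_workers)]
#     for w in workers:
#         w.start()
#     que.put({'output': result_img, 'save_path': 'results/out.png'})
#     for _ in workers:
#         que.put('quit')
#     for w in workers:
#         w.join()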