utils.py

import cv2
import math
import numpy as np
import os
import queue
import threading
import torch
from basicsr.utils.download_util import load_file_from_url
from torch.nn import functional as F

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


class RealESRGANer():
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model_path (str): The path to the pretrained model. It can also be a URL, in which case the weights are
            first downloaded automatically.
        model (nn.Module): The defined network. Default: None.
        tile (int): Since large images can cause GPU out-of-memory errors, this tile option first crops
            input images into tiles, processes each of them, and finally merges them back into one image.
            0 denotes not using tiles. Default: 0.
        tile_pad (int): The pad size for each tile, used to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (bool): Whether to use half precision during inference. Default: False.
    """

    def __init__(self,
                 scale,
                 model_path,
                 model=None,
                 tile=0,
                 tile_pad=10,
                 pre_pad=10,
                 half=False,
                 device=None,
                 gpu_id=None):
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half

        # initialize model
        if gpu_id is not None:  # use `is not None` so that gpu_id=0 selects cuda:0 explicitly
            self.device = torch.device(
                f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
        else:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
        # if model_path starts with https, first download the weights to the folder: realesrgan/weights
        if model_path.startswith('https://'):
            model_path = load_file_from_url(
                url=model_path, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
        loadnet = torch.load(model_path, map_location=torch.device('cpu'))
        # prefer to use params_ema
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        model.load_state_dict(loadnet[keyname], strict=True)
        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    def pre_process(self, img):
        """Pre-process the input: apply pre-pad and mod pad so that the image dimensions
        are divisible by the scale-dependent factor.
        """
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad for divisible borders
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if (h % self.mod_scale != 0):
                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
            if (w % self.mod_scale != 0):
                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
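
    # Worked example of the mod padding above (illustrative, assuming scale=2 so
    # mod_scale=2): if the already pre-padded tensor is 101 x 150, then
    # mod_pad_h = 1 and mod_pad_w = 0, padding it to 102 x 150 so that both
    # sides are divisible by mod_scale.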

    def process(self):
        # model inference
        self.output = self.model(self.img)

    def tile_process(self):
        """Crop the input image into tiles, process each tile, and merge the
        processed tiles back into one image.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with a black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)
                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)
                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print('Error', error)
                    # skip this tile on failure; its region of the output stays black
                    continue
                print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                       output_start_x_tile:output_end_x_tile]
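
    # Worked example of the tile arithmetic above (illustrative, assuming
    # tile_size=400, tile_pad=10, scale=4, and a sufficiently wide image): the
    # tile at x=1 covers input columns 400..800 and, with padding, reads columns
    # 390..810. Its upscaled result fills output columns 1600..3200, copied from
    # the padded output tile starting at offset (400 - 390) * 4 = 40.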

    def post_process(self):
        # remove extra mod pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
        # remove pre_pad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
        return self.output
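
    # Illustrative note: with pre_pad=10 and scale=4, the crop above removes the
    # last 10 * 4 = 40 rows and columns, which correspond to the reflected
    # pre-pad border after upsampling.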

    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
        h_input, w_input = img.shape[0:2]
        # img: numpy array
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == 'realesrgan':
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        output_img = self.post_process()
        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == 'RGBA':
            if alpha_upsampler == 'realesrgan':
                self.pre_process(alpha)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_alpha = self.post_process()
                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use cv2.resize for the alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output, (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ), interpolation=cv2.INTER_LANCZOS4)

        return output, img_mode
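

# Example usage of RealESRGANer (a minimal sketch, not part of the original
# module; RRDBNet and the weight path are assumptions borrowed from the wider
# Real-ESRGAN project):
#
#     from basicsr.archs.rrdbnet_arch import RRDBNet
#     model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
#     upsampler = RealESRGANer(scale=4, model_path='weights/RealESRGAN_x4plus.pth',
#                              model=model, tile=400, tile_pad=10, pre_pad=10, half=True)
#     img = cv2.imread('input.png', cv2.IMREAD_UNCHANGED)
#     output, _ = upsampler.enhance(img, outscale=4)
#     cv2.imwrite('output.png', output)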


class PrefetchReader(threading.Thread):
    """Prefetch images in a background thread.

    Args:
        img_list (list[str]): A list of image paths to be read.
        num_prefetch_queue (int): Capacity of the prefetch queue.
    """

    def __init__(self, img_list, num_prefetch_queue):
        super().__init__()
        self.que = queue.Queue(num_prefetch_queue)
        self.img_list = img_list

    def run(self):
        for img_path in self.img_list:
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            self.que.put(img)
        self.que.put(None)  # sentinel that signals the end of the list

    def __next__(self):
        next_item = self.que.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self
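

# Example usage of PrefetchReader (a minimal sketch; the file names are
# hypothetical). Note that cv2.imread returns None on a failed read, which
# would end iteration early because None is also the end-of-list sentinel:
#
#     reader = PrefetchReader(['a.png', 'b.png'], num_prefetch_queue=4)
#     reader.start()
#     for img in reader:
#         ...  # consume images as they become available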


class IOConsumer(threading.Thread):
    """Consume save requests from a queue and write the images to disk."""

    def __init__(self, opt, que, qid):
        super().__init__()
        self._queue = que
        self.qid = qid
        self.opt = opt

    def run(self):
        while True:
            msg = self._queue.get()
            if isinstance(msg, str) and msg == 'quit':
                break

            output = msg['output']
            save_path = msg['save_path']
            cv2.imwrite(save_path, output)
        print(f'IO worker {self.qid} is done.')
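

# Example usage of IOConsumer (a minimal sketch; `opt`, the image variable, and
# the save path are hypothetical):
#
#     que = queue.Queue()
#     consumers = [IOConsumer(opt={}, que=que, qid=i) for i in range(2)]
#     for consumer in consumers:
#         consumer.start()
#     que.put({'output': img, 'save_path': 'results/out.png'})
#     for _ in consumers:
#         que.put('quit')  # one quit message per worker
#     for consumer in consumers:
#         consumer.join()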