utils.py

import cv2
import math
import numpy as np
import os
import torch
from basicsr.utils.download_util import load_file_from_url
from torch.nn import functional as F

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


class RealESRGANer():
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model_path (str): The path to the pretrained model. It can also be a URL (the model will first be
            downloaded automatically).
        model (nn.Module): The defined network. Default: None.
        tile (int): Since very large images can cause out-of-GPU-memory issues, this tile option first crops
            the input image into tiles, processes each tile, and finally merges them back into one image.
            0 denotes not using tiles. Default: 0.
        tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (bool): Whether to use half precision during inference. Default: False.
    """

    def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False):
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half

        # initialize model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # if model_path starts with https, first download the weights to the folder: realesrgan/weights
        if model_path.startswith('https://'):
            model_path = load_file_from_url(
                url=model_path, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
        # map to CPU first so that loading also works on machines without a GPU
        loadnet = torch.load(model_path, map_location=torch.device('cpu'))
        # prefer to use params_ema (the exponential moving average weights)
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        model.load_state_dict(loadnet[keyname], strict=True)
        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    def pre_process(self, img):
        """Pre-process, such as pre-pad and mod pad, so that the image sizes are divisible."""
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad for divisible borders
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if (h % self.mod_scale != 0):
                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
            if (w % self.mod_scale != 0):
                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
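            # a worked example of the arithmetic above (illustrative numbers): with
            # mod_scale=4 and a 511x765 input, mod_pad_h = 4 - 511 % 4 = 1 and
            # mod_pad_w = 4 - 765 % 4 = 3, so the image is padded to 512x768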
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')

    def process(self):
        # model inference
        self.output = self.model(self.img)

    def tile_process(self):
        """It will first crop the input image into tiles, then process each tile.
        Finally, all the processed tiles are merged into one image.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with a black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)
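        # for example (illustrative numbers), a 1500x1000 image with tile_size=400
        # is split into ceil(1500/400) x ceil(1000/400) = 4 x 3 = 12 tiles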

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)
                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)
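                # the padded bounds add a halo of up to tile_pad pixels of real image
                # content around the tile; the halo is cropped away after upscaling,
                # which suppresses seams at the tile borders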

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print('Error', error)
                print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                       output_start_x_tile:output_end_x_tile]

    def post_process(self):
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
        # remove prepad
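        # both pads were applied at input resolution, so the trim here is scaled:
        # e.g. (illustrative numbers) pre_pad=10 at scale=4 removes 40 rows/columns
        # from the bottom and right of the output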
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
        return self.output

    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == 'realesrgan':
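                # the network expects a 3-channel input, so the single alpha
                # channel is replicated into three channels before upsampling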
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        output_img = self.post_process()
        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == 'RGBA':
            if alpha_upsampler == 'realesrgan':
                self.pre_process(alpha)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_alpha = self.post_process()
                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use the cv2 resize for alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)
        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output, (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ), interpolation=cv2.INTER_LANCZOS4)
        return output, img_mode
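

# ------------------------------------------------------------------------------ #
# A minimal usage sketch, not part of the original file. It assumes basicsr's
# RRDBNet architecture (the backbone used by the x4 RealESRGAN models) and an
# illustrative local weights path; substitute whichever pretrained model and
# paths you actually use.
if __name__ == '__main__':
    from basicsr.archs.rrdbnet_arch import RRDBNet

    net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
    upsampler = RealESRGANer(
        scale=4,
        model_path='weights/RealESRGAN_x4plus.pth',  # assumed local path (an https URL would be downloaded first)
        model=net,
        tile=400,  # tile to bound GPU memory; 0 disables tiling
        tile_pad=10,
        pre_pad=10,
        half=False)
    # enhance() expects a BGR (or gray/BGRA) numpy image, as returned by cv2.imread
    img = cv2.imread('input.jpg', cv2.IMREAD_UNCHANGED)
    output, img_mode = upsampler.enhance(img, outscale=4)
    cv2.imwrite('output.png', output)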