utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. import cv2
  2. import math
  3. import numpy as np
  4. import os
  5. import torch
  6. from basicsr.archs.rrdbnet_arch import RRDBNet
  7. from torch.hub import download_url_to_file, get_dir
  8. from torch.nn import functional as F
  9. from urllib.parse import urlparse
  10. ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  11. class RealESRGANer():
  12. """A helper class for upsampling images with RealESRGAN.
  13. Args:
  14. scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
  15. model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
  16. model (nn.Module): The defined network. If None, the model will be constructed here. Default: None.
  17. tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
  18. input images into tiles, and then process each of them. Finally, they will be merged into one image.
  19. 0 denotes for do not use tile. Default: 0.
  20. tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
  21. pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
  22. half (float): Whether to use half precision during inference. Default: False.
  23. """
  24. def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False):
  25. self.scale = scale
  26. self.tile_size = tile
  27. self.tile_pad = tile_pad
  28. self.pre_pad = pre_pad
  29. self.mod_scale = None
  30. self.half = half
  31. # initialize model
  32. self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  33. if model is None:
  34. model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
  35. # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
  36. if model_path.startswith('https://'):
  37. model_path = load_file_from_url(
  38. url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None)
  39. loadnet = torch.load(model_path)
  40. # prefer to use params_ema
  41. if 'params_ema' in loadnet:
  42. keyname = 'params_ema'
  43. else:
  44. keyname = 'params'
  45. model.load_state_dict(loadnet[keyname], strict=True)
  46. model.eval()
  47. self.model = model.to(self.device)
  48. if self.half:
  49. self.model = self.model.half()
  50. def pre_process(self, img):
  51. """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
  52. """
  53. img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
  54. self.img = img.unsqueeze(0).to(self.device)
  55. if self.half:
  56. self.img = self.img.half()
  57. # pre_pad
  58. if self.pre_pad != 0:
  59. self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
  60. # mod pad for divisible borders
  61. if self.scale == 2:
  62. self.mod_scale = 2
  63. elif self.scale == 1:
  64. self.mod_scale = 4
  65. if self.mod_scale is not None:
  66. self.mod_pad_h, self.mod_pad_w = 0, 0
  67. _, _, h, w = self.img.size()
  68. if (h % self.mod_scale != 0):
  69. self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
  70. if (w % self.mod_scale != 0):
  71. self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
  72. self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
  73. def process(self):
  74. # model inference
  75. self.output = self.model(self.img)
  76. def tile_process(self):
  77. """It will first crop input images to tiles, and then process each tile.
  78. Finally, all the processed tiles are merged into one images.
  79. Modified from: https://github.com/ata4/esrgan-launcher
  80. """
  81. batch, channel, height, width = self.img.shape
  82. output_height = height * self.scale
  83. output_width = width * self.scale
  84. output_shape = (batch, channel, output_height, output_width)
  85. # start with black image
  86. self.output = self.img.new_zeros(output_shape)
  87. tiles_x = math.ceil(width / self.tile_size)
  88. tiles_y = math.ceil(height / self.tile_size)
  89. # loop over all tiles
  90. for y in range(tiles_y):
  91. for x in range(tiles_x):
  92. # extract tile from input image
  93. ofs_x = x * self.tile_size
  94. ofs_y = y * self.tile_size
  95. # input tile area on total image
  96. input_start_x = ofs_x
  97. input_end_x = min(ofs_x + self.tile_size, width)
  98. input_start_y = ofs_y
  99. input_end_y = min(ofs_y + self.tile_size, height)
  100. # input tile area on total image with padding
  101. input_start_x_pad = max(input_start_x - self.tile_pad, 0)
  102. input_end_x_pad = min(input_end_x + self.tile_pad, width)
  103. input_start_y_pad = max(input_start_y - self.tile_pad, 0)
  104. input_end_y_pad = min(input_end_y + self.tile_pad, height)
  105. # input tile dimensions
  106. input_tile_width = input_end_x - input_start_x
  107. input_tile_height = input_end_y - input_start_y
  108. tile_idx = y * tiles_x + x + 1
  109. input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
  110. # upscale tile
  111. try:
  112. with torch.no_grad():
  113. output_tile = self.model(input_tile)
  114. except Exception as error:
  115. print('Error', error)
  116. print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
  117. # output tile area on total image
  118. output_start_x = input_start_x * self.scale
  119. output_end_x = input_end_x * self.scale
  120. output_start_y = input_start_y * self.scale
  121. output_end_y = input_end_y * self.scale
  122. # output tile area without padding
  123. output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
  124. output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
  125. output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
  126. output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
  127. # put tile into output image
  128. self.output[:, :, output_start_y:output_end_y,
  129. output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
  130. output_start_x_tile:output_end_x_tile]
  131. def post_process(self):
  132. # remove extra pad
  133. if self.mod_scale is not None:
  134. _, _, h, w = self.output.size()
  135. self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
  136. # remove prepad
  137. if self.pre_pad != 0:
  138. _, _, h, w = self.output.size()
  139. self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
  140. return self.output
  141. @torch.no_grad()
  142. def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
  143. h_input, w_input = img.shape[0:2]
  144. # img: numpy
  145. img = img.astype(np.float32)
  146. if np.max(img) > 256: # 16-bit image
  147. max_range = 65535
  148. print('\tInput is a 16-bit image')
  149. else:
  150. max_range = 255
  151. img = img / max_range
  152. if len(img.shape) == 2: # gray image
  153. img_mode = 'L'
  154. img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
  155. elif img.shape[2] == 4: # RGBA image with alpha channel
  156. img_mode = 'RGBA'
  157. alpha = img[:, :, 3]
  158. img = img[:, :, 0:3]
  159. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  160. if alpha_upsampler == 'realesrgan':
  161. alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
  162. else:
  163. img_mode = 'RGB'
  164. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  165. # ------------------- process image (without the alpha channel) ------------------- #
  166. self.pre_process(img)
  167. if self.tile_size > 0:
  168. self.tile_process()
  169. else:
  170. self.process()
  171. output_img = self.post_process()
  172. output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
  173. output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
  174. if img_mode == 'L':
  175. output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
  176. # ------------------- process the alpha channel if necessary ------------------- #
  177. if img_mode == 'RGBA':
  178. if alpha_upsampler == 'realesrgan':
  179. self.pre_process(alpha)
  180. if self.tile_size > 0:
  181. self.tile_process()
  182. else:
  183. self.process()
  184. output_alpha = self.post_process()
  185. output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
  186. output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
  187. output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
  188. else: # use the cv2 resize for alpha channel
  189. h, w = alpha.shape[0:2]
  190. output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
  191. # merge the alpha channel
  192. output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
  193. output_img[:, :, 3] = output_alpha
  194. # ------------------------------ return ------------------------------ #
  195. if max_range == 65535: # 16-bit image
  196. output = (output_img * 65535.0).round().astype(np.uint16)
  197. else:
  198. output = (output_img * 255.0).round().astype(np.uint8)
  199. if outscale is not None and outscale != float(self.scale):
  200. output = cv2.resize(
  201. output, (
  202. int(w_input * outscale),
  203. int(h_input * outscale),
  204. ), interpolation=cv2.INTER_LANCZOS4)
  205. return output, img_mode
  206. def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
  207. """Load file form http url, will download models if necessary.
  208. Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
  209. """
  210. if model_dir is None:
  211. hub_dir = get_dir()
  212. model_dir = os.path.join(hub_dir, 'checkpoints')
  213. os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
  214. parts = urlparse(url)
  215. filename = os.path.basename(parts.path)
  216. if file_name is not None:
  217. filename = file_name
  218. cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
  219. if not os.path.exists(cached_file):
  220. print(f'Downloading: "{url}" to {cached_file}\n')
  221. download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
  222. return cached_file