import cv2
import math
import numpy as np
import os
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from torch.nn import functional as F

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


class RealESRGANer():
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model_path (str): The path to the pretrained model. It can also be a URL
            (the model will first be downloaded automatically).
        model (nn.Module): The defined network. If None, the model will be constructed here. Default: None.
        tile (int): Since large input images can cause GPU out-of-memory errors, this tile option
            first crops the input image into tiles, processes each of them, and finally merges
            them back into one image. 0 denotes no tiling. Default: 0.
        tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (bool): Whether to use half precision during inference. Default: False.
    """
    def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False):
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half

        # initialize model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if model is None:
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
        # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
        if model_path.startswith('https://'):
            model_path = load_file_from_url(
                url=model_path, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
        # map to CPU first so that CUDA-saved checkpoints also load on CPU-only machines
        loadnet = torch.load(model_path, map_location=torch.device('cpu'))
        # prefer to use params_ema
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        model.load_state_dict(loadnet[keyname], strict=True)
        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()
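
    # Note (illustrative, not part of the original file): a custom architecture can be
    # passed in instead of the default 23-block RRDBNet, e.g. a lighter variant, as long
    # as the checkpoint at `model_path` matches it. The path below is hypothetical:
    #   model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6,
    #                   num_grow_ch=32, scale=4)
    #   upsampler = RealESRGANer(scale=4, model_path='weights/custom_x4.pth', model=model)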
    def pre_process(self, img):
        """Pre-process, such as pre_pad and mod pad, so that the image dimensions are divisible."""
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad for divisible borders
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if (h % self.mod_scale != 0):
                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
            if (w % self.mod_scale != 0):
                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
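
    # Worked example of the padding above (illustrative, not in the original file):
    # with scale=2 (so mod_scale=2) and pre_pad=10, a 500x500 input becomes 510x510
    # after the reflect pre-pad; 510 % 2 == 0, so no mod padding is added. A 501x501
    # input becomes 511x511, and mod padding adds one extra row and one extra column.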
    def process(self):
        # model inference
        self.output = self.model(self.img)
    def tile_process(self):
        """It will first crop input images to tiles, and then process each tile.
        Finally, all the processed tiles are merged into one image.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print('Error', error)
                    print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
                    # skip this tile rather than reuse a stale (or undefined) output_tile;
                    # its region in the output image simply stays black
                    continue

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                       output_start_x_tile:output_end_x_tile]
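
    # Worked example of the tile arithmetic above (illustrative, not in the original
    # file): with tile_size=400, tile_pad=10 and scale=4, the tile at x=1, y=0 covers
    # input columns [400, 800); the padded crop covers [390, 810), i.e. 420 columns.
    # After 4x upscaling the padded tile is 1680 px wide, and the slice [40, 1640)
    # (10 * 4 px trimmed from each side) is written to output columns [1600, 3200).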
    def post_process(self):
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
        # remove prepad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
        return self.output
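
    # e.g. with pre_pad=10 and scale=4, the 10-pixel reflect pad added on the right and
    # bottom in pre_process becomes 40 output pixels on each of those edges, which is
    # exactly what the slicing above crops away (illustrative note, not in the original file).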
    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == 'realesrgan':
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        output_img = self.post_process()
        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == 'RGBA':
            if alpha_upsampler == 'realesrgan':
                self.pre_process(alpha)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_alpha = self.post_process()
                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use the cv2 resize for alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output, (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ), interpolation=cv2.INTER_LANCZOS4)

        return output, img_mode
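

# A minimal usage sketch, not part of the original module. The weight URL is the
# public RealESRGAN_x4plus release checkpoint; the input/output file names are
# assumptions for illustration.
if __name__ == '__main__':
    upsampler = RealESRGANer(
        scale=4,
        model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
        tile=400,  # process in 400x400 tiles to bound GPU memory; 0 disables tiling
        half=torch.cuda.is_available())  # half precision is only useful on GPU
    img = cv2.imread('input.png', cv2.IMREAD_UNCHANGED)  # BGR/BGRA, uint8 or uint16
    output, img_mode = upsampler.enhance(img, outscale=4)
    cv2.imwrite('output.png', output)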