@@ -1,6 +1,7 @@
 import argparse
 import cv2
 import glob
+import math
 import numpy as np
 import os
 import torch
@@ -10,64 +11,233 @@ from torch.nn import functional as F

 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth')
-    parser.add_argument('--scale', type=int, default=4)
-    parser.add_argument('--suffix', type=str, default='_out')
-    parser.add_argument('--input', type=str, default='inputs', help='input image or folder')
+    parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
+    parser.add_argument(
+        '--model_path',
+        type=str,
+        default='experiments/pretrained_models/RealESRGAN_x4plus.pth',
+        help='Path to the pre-trained model')
+    parser.add_argument('--scale', type=int, default=4, help='Upsample scale factor')
+    parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
+    parser.add_argument('--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
+    parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
+    parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
+    parser.add_argument(
+        '--alpha_upsampler',
+        type=str,
+        default='realesrgan',
+        help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
+    parser.add_argument(
+        '--extension',
+        type=str,
+        default='auto',
+        help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
     args = parser.parse_args()

-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    # set up model
-    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=args.scale)
-    loadnet = torch.load(args.model_path)
-    if 'params_ema' in loadnet:
-        keyname = 'params_ema'
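+    # RealESRGANer (defined below) wraps model loading, pre/mod padding, and tiled inference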
+    upsampler = RealESRGANer(
+        scale=args.scale, model_path=args.model_path, tile=args.tile, tile_pad=args.tile_pad, pre_pad=args.pre_pad)
+    os.makedirs('results/', exist_ok=True)
+    if os.path.isfile(args.input):
+        paths = [args.input]
     else:
-        keyname = 'params'
-    model.load_state_dict(loadnet[keyname], strict=True)
-    model.eval()
-    model = model.to(device)
+        paths = sorted(glob.glob(os.path.join(args.input, '*')))

-    os.makedirs('results/', exist_ok=True)
-    for idx, path in enumerate(sorted(glob.glob(os.path.join(args.input, '*')))):
-        imgname = os.path.splitext(os.path.basename(path))[0]
+    for idx, path in enumerate(paths):
+        imgname, extension = os.path.splitext(os.path.basename(path))
         print('Testing', idx, imgname)
-        # read image
-        img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
-        img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
-        img = img.unsqueeze(0).to(device)
-
-        if args.scale == 2:
-            mod_scale = 2
-        elif args.scale == 1:
-            mod_scale = 4
+
+        # ------------------------------ read image ------------------------------ #
+        img = cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32)
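+        # infer the dynamic range: pixel values above 255 indicate a 16-bit input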
+        if np.max(img) > 255:  # 16-bit image
+            max_range = 65535
+            print('\tInput is a 16-bit image')
+        else:
+            max_range = 255
+        img = img / max_range
+        if len(img.shape) == 2:  # gray image
+            img_mode = 'L'
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+        elif img.shape[2] == 4:  # RGBA image with alpha channel
+            img_mode = 'RGBA'
+            alpha = img[:, :, 3]
+            img = img[:, :, 0:3]
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            if args.alpha_upsampler == 'realesrgan':
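+                # the model expects a 3-channel image, so replicate alpha across RGB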
+                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
+        else:
+            img_mode = 'RGB'
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+        # ------------------- process image (without the alpha channel) ------------------- #
+        upsampler.pre_process(img)
+        if args.tile:
+            upsampler.tile_process()
+        else:
+            upsampler.process()
+        output_img = upsampler.post_process()
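+        # tensor (1, C, H, W), RGB, [0, 1] -> numpy (H, W, C), BGR for cv2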
+        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
+        if img_mode == 'L':
+            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+
+        # ------------------- process the alpha channel if necessary ------------------- #
+        if img_mode == 'RGBA':
+            if args.alpha_upsampler == 'realesrgan':
+                upsampler.pre_process(alpha)
+                if args.tile:
+                    upsampler.tile_process()
+                else:
+                    upsampler.process()
+                output_alpha = upsampler.post_process()
+                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
+                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
+            else:
+                h, w = alpha.shape[0:2]
+                output_alpha = cv2.resize(alpha, (w * args.scale, h * args.scale), interpolation=cv2.INTER_LINEAR)
+
+            # merge the alpha channel
+            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
+            output_img[:, :, 3] = output_alpha
+
+        # ------------------------------ save image ------------------------------ #
+        if args.extension == 'auto':
+            extension = extension[1:]
+        else:
+            extension = args.extension
+        if img_mode == 'RGBA':  # RGBA images should be saved in png format
+            extension = 'png'
+        save_path = f'results/{imgname}_{args.suffix}.{extension}'
+        if max_range == 65535:  # 16-bit image
+            output = (output_img * 65535.0).round().astype(np.uint16)
+        else:
+            output = (output_img * 255.0).round().astype(np.uint8)
+        cv2.imwrite(save_path, output)
+
+
+class RealESRGANer():
+
+    def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10):
+        self.scale = scale
+        self.tile_size = tile
+        self.tile_pad = tile_pad
+        self.pre_pad = pre_pad
+        self.mod_scale = None
+
+        # initialize model
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
+        loadnet = torch.load(model_path)
+        if 'params_ema' in loadnet:
+            keyname = 'params_ema'
         else:
-            mod_scale = None
-        if mod_scale is not None:
-            h_pad, w_pad = 0, 0
-            _, _, h, w = img.size()
-            if (h % mod_scale != 0):
-                h_pad = (mod_scale - h % mod_scale)
-            if (w % mod_scale != 0):
-                w_pad = (mod_scale - w % mod_scale)
-            img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect')
+            keyname = 'params'
+        model.load_state_dict(loadnet[keyname], strict=True)
+        model.eval()
+        self.model = model.to(self.device)
+
+    def pre_process(self, img):
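+        # numpy (H, W, C) -> tensor (1, C, H, W) on the model device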
+        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
+        self.img = img.unsqueeze(0).to(self.device)
+
+        # pre_pad
+        if self.pre_pad != 0:
+            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
+        # mod pad
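+        # pad H and W up to a multiple of mod_scale (only needed for the x1/x2 scales);
+        # post_process() crops this padding off again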
+        if self.scale == 2:
+            self.mod_scale = 2
+        elif self.scale == 1:
+            self.mod_scale = 4
+        if self.mod_scale is not None:
+            self.mod_pad_h, self.mod_pad_w = 0, 0
+            _, _, h, w = self.img.size()
+            if (h % self.mod_scale != 0):
+                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
+            if (w % self.mod_scale != 0):
+                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
+            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')

+    def process(self):
         try:
             # inference
             with torch.no_grad():
-                output = model(img)
-            # remove extra pad
-            if mod_scale is not None:
-                _, _, h, w = output.size()
-                output = output[:, :, 0:h - h_pad, 0:w - w_pad]
-            # save image
-            output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
-            output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
-            output = (output * 255.0).round().astype(np.uint8)
-            cv2.imwrite(f'results/{imgname}_{args.suffix}.png', output)
+                self.output = self.model(self.img)
         except Exception as error:
             print('Error', error)

+    def tile_process(self):
+        """Modified from: https://github.com/ata4/esrgan-launcher
+        """
+        batch, channel, height, width = self.img.shape
+        output_height = height * self.scale
+        output_width = width * self.scale
+        output_shape = (batch, channel, output_height, output_width)
+
+        # start with black image
+        self.output = self.img.new_zeros(output_shape)
+        tiles_x = math.ceil(width / self.tile_size)
+        tiles_y = math.ceil(height / self.tile_size)
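+        # each tile is fed to the model with tile_pad pixels of extra context;
+        # that padded border is cropped from the output to avoid seams between tiles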
+
+        # loop over all tiles
+        for y in range(tiles_y):
+            for x in range(tiles_x):
+                # extract tile from input image
+                ofs_x = x * self.tile_size
+                ofs_y = y * self.tile_size
+                # input tile area on total image
+                input_start_x = ofs_x
+                input_end_x = min(ofs_x + self.tile_size, width)
+                input_start_y = ofs_y
+                input_end_y = min(ofs_y + self.tile_size, height)
+
+                # input tile area on total image with padding
+                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
+                input_end_x_pad = min(input_end_x + self.tile_pad, width)
+                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
+                input_end_y_pad = min(input_end_y + self.tile_pad, height)
+
+                # input tile dimensions
+                input_tile_width = input_end_x - input_start_x
+                input_tile_height = input_end_y - input_start_y
+                tile_idx = y * tiles_x + x + 1
+                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
+
+                # upscale tile
+                try:
+                    with torch.no_grad():
+                        output_tile = self.model(input_tile)
+                except Exception as error:
+                    print('Error', error)
+                print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
+
+                # output tile area on total image
+                output_start_x = input_start_x * self.scale
+                output_end_x = input_end_x * self.scale
+                output_start_y = input_start_y * self.scale
+                output_end_y = input_end_y * self.scale
+
+                # output tile area without padding
+                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
+                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
+                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
+                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
+
+                # put tile into output image
+                self.output[:, :, output_start_y:output_end_y,
+                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
+                                                                       output_start_x_tile:output_end_x_tile]
+
+    def post_process(self):
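+        # the pads were applied to the input, so they are self.scale times larger in the output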
+        # remove extra pad
+        if self.mod_scale is not None:
+            _, _, h, w = self.output.size()
+            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
+        # remove pre_pad
+        if self.pre_pad != 0:
+            _, _, h, w = self.output.size()
+            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
+        return self.output
+

 if __name__ == '__main__':
     main()