inference_realesrgan.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. import argparse
  2. import cv2
  3. import glob
  4. import os
  5. from basicsr.archs.rrdbnet_arch import RRDBNet
  6. from realesrgan import RealESRGANer
  7. def main():
  8. parser = argparse.ArgumentParser()
  9. parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
  10. parser.add_argument(
  11. '--model_path',
  12. type=str,
  13. default='experiments/pretrained_models/RealESRGAN_x4plus.pth',
  14. help='Path to the pre-trained model')
  15. parser.add_argument('--output', type=str, default='results', help='Output folder')
  16. parser.add_argument('--netscale', type=int, default=4, help='Upsample scale factor of the network')
  17. parser.add_argument('--outscale', type=float, default=4, help='The final upsampling scale of the image')
  18. parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
  19. parser.add_argument('--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
  20. parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
  21. parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
  22. parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
  23. parser.add_argument('--half', action='store_true', help='Use half precision during inference')
  24. parser.add_argument('--block', type=int, default=23, help='num_block in RRDB')
  25. parser.add_argument(
  26. '--alpha_upsampler',
  27. type=str,
  28. default='realesrgan',
  29. help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
  30. parser.add_argument(
  31. '--ext',
  32. type=str,
  33. default='auto',
  34. help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
  35. args = parser.parse_args()
  36. if 'RealESRGAN_x4plus_anime_6B.pth' in args.model_path:
  37. args.block = 6
  38. elif 'RealESRGAN_x2plus.pth' in args.model_path:
  39. args.netscale = 2
  40. model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=args.block, num_grow_ch=32, scale=args.netscale)
  41. upsampler = RealESRGANer(
  42. scale=args.netscale,
  43. model_path=args.model_path,
  44. model=model,
  45. tile=args.tile,
  46. tile_pad=args.tile_pad,
  47. pre_pad=args.pre_pad,
  48. half=args.half)
  49. if args.face_enhance:
  50. from gfpgan import GFPGANer
  51. face_enhancer = GFPGANer(
  52. model_path='https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth',
  53. upscale=args.outscale,
  54. arch='clean',
  55. channel_multiplier=2,
  56. bg_upsampler=upsampler)
  57. os.makedirs(args.output, exist_ok=True)
  58. if os.path.isfile(args.input):
  59. paths = [args.input]
  60. else:
  61. paths = sorted(glob.glob(os.path.join(args.input, '*')))
  62. for idx, path in enumerate(paths):
  63. imgname, extension = os.path.splitext(os.path.basename(path))
  64. print('Testing', idx, imgname)
  65. img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
  66. if len(img.shape) == 3 and img.shape[2] == 4:
  67. img_mode = 'RGBA'
  68. else:
  69. img_mode = None
  70. h, w = img.shape[0:2]
  71. if max(h, w) > 1000 and args.netscale == 4:
  72. import warnings
  73. warnings.warn('The input image is large, try X2 model for better performace.')
  74. if max(h, w) < 500 and args.netscale == 2:
  75. import warnings
  76. warnings.warn('The input image is small, try X4 model for better performace.')
  77. try:
  78. if args.face_enhance:
  79. _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
  80. else:
  81. output, _ = upsampler.enhance(img, outscale=args.outscale)
  82. except Exception as error:
  83. print('Error', error)
  84. print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
  85. else:
  86. if args.ext == 'auto':
  87. extension = extension[1:]
  88. else:
  89. extension = args.ext
  90. if img_mode == 'RGBA': # RGBA images should be saved in png format
  91. extension = 'png'
  92. save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
  93. cv2.imwrite(save_path, output)
  94. if __name__ == '__main__':
  95. main()