# inference_realesrgan.py (3.0 KB)
  1. import argparse
  2. import cv2
  3. import glob
  4. import os
  5. from realesrgan import RealESRGANer
  6. def main():
  7. parser = argparse.ArgumentParser()
  8. parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
  9. parser.add_argument(
  10. '--model_path',
  11. type=str,
  12. default='experiments/pretrained_models/RealESRGAN_x4plus.pth',
  13. help='Path to the pre-trained model')
  14. parser.add_argument('--output', type=str, default='results', help='Output folder')
  15. parser.add_argument('--netscale', type=int, default=4, help='Upsample scale factor of the network')
  16. parser.add_argument('--outscale', type=float, default=4, help='The final upsampling scale of the image')
  17. parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
  18. parser.add_argument('--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
  19. parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
  20. parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
  21. parser.add_argument('--half', action='store_true', help='Use half precision during inference')
  22. parser.add_argument(
  23. '--alpha_upsampler',
  24. type=str,
  25. default='realesrgan',
  26. help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
  27. parser.add_argument(
  28. '--ext',
  29. type=str,
  30. default='auto',
  31. help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
  32. args = parser.parse_args()
  33. upsampler = RealESRGANer(
  34. scale=args.netscale,
  35. model_path=args.model_path,
  36. tile=args.tile,
  37. tile_pad=args.tile_pad,
  38. pre_pad=args.pre_pad,
  39. half=args.half)
  40. os.makedirs(args.output, exist_ok=True)
  41. if os.path.isfile(args.input):
  42. paths = [args.input]
  43. else:
  44. paths = sorted(glob.glob(os.path.join(args.input, '*')))
  45. for idx, path in enumerate(paths):
  46. imgname, extension = os.path.splitext(os.path.basename(path))
  47. print('Testing', idx, imgname)
  48. img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
  49. h, w = img.shape[0:2]
  50. if max(h, w) > 1000 and args.netscale == 4:
  51. print('WARNING: The input image is large, try X2 model for better performace.')
  52. if max(h, w) < 500 and args.netscale == 2:
  53. print('WARNING: The input image is small, try X4 model for better performace.')
  54. try:
  55. output, img_mode = upsampler.enhance(img, outscale=args.outscale)
  56. except Exception as error:
  57. print('Error', error)
  58. else:
  59. if args.ext == 'auto':
  60. extension = extension[1:]
  61. else:
  62. extension = args.ext
  63. if img_mode == 'RGBA': # RGBA images should be saved in png format
  64. extension = 'png'
  65. save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
  66. cv2.imwrite(save_path, output)
  67. if __name__ == '__main__':
  68. main()