# inference_realesrgan_video.py
  1. import argparse
  2. import cv2
  3. import glob
  4. import mimetypes
  5. import numpy as np
  6. import os
  7. import queue
  8. import shutil
  9. import torch
  10. from basicsr.archs.rrdbnet_arch import RRDBNet
  11. from basicsr.utils.logger import AvgTimer
  12. from tqdm import tqdm
  13. from realesrgan import IOConsumer, PrefetchReader, RealESRGANer
  14. from realesrgan.archs.srvgg_arch import SRVGGNetCompact
  15. def get_frames(args, extract_frames=False):
  16. # input can be a video file / a folder of frames / an image
  17. is_video = False
  18. if mimetypes.guess_type(args.input)[0].startswith('video'): # is a video file
  19. is_video = True
  20. video_name = os.path.splitext(os.path.basename(args.input))[0]
  21. if extract_frames:
  22. frame_folder = os.path.join('tmp_frames', video_name)
  23. os.makedirs(frame_folder, exist_ok=True)
  24. # use ffmpeg to extract frames
  25. os.system(f'ffmpeg -i {args.input} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0 {frame_folder}/frame%08d.png')
  26. # get image path list
  27. paths = sorted(glob.glob(os.path.join(frame_folder, '*')))
  28. else:
  29. paths = []
  30. # get input video fps
  31. if args.fps is None:
  32. import ffmpeg
  33. probe = ffmpeg.probe(args.input)
  34. video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
  35. args.fps = eval(video_streams[0]['avg_frame_rate'])
  36. elif mimetypes.guess_type(args.input)[0].startswith('image'): # is an image file
  37. paths = [args.input]
  38. else:
  39. paths = sorted(glob.glob(os.path.join(args.input, '*')))
  40. assert len(paths) > 0, 'the input folder is empty'
  41. if args.fps is None:
  42. args.fps = 24
  43. return is_video, paths
  44. def inference_stream(args, upsampler, face_enhancer):
  45. try:
  46. import ffmpeg
  47. except ImportError:
  48. import pip
  49. pip.main(['install', '--user', 'ffmpeg-python'])
  50. import ffmpeg
  51. is_video, paths = get_frames(args, extract_frames=False)
  52. video_name = os.path.splitext(os.path.basename(args.input))[0]
  53. video_save_path = os.path.join(args.output, f'{video_name}_{args.suffix}.mp4')
  54. # decoder
  55. if is_video:
  56. # get height and width
  57. probe = ffmpeg.probe(args.input)
  58. video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
  59. width = video_streams[0]['width']
  60. height = video_streams[0]['height']
  61. # set up frame decoder
  62. decoder = (
  63. ffmpeg.input(args.input).output('pipe:', format='rawvideo', pix_fmt='rgb24', loglevel='warning').run_async(
  64. pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
  65. else:
  66. from PIL import Image
  67. tmp_img = Image.open(paths[0])
  68. width, height = tmp_img.size
  69. idx = 0
  70. out_width, out_height = int(width * args.outscale), int(height * args.outscale)
  71. if out_height > 2160:
  72. print('You are generating video that is larger than 4K, which will be very slow due to IO speed.',
  73. 'We highly recommend to decrease the outscale(aka, -s).')
  74. # encoder
  75. if is_video:
  76. audio = ffmpeg.input(args.input).audio
  77. encoder = (
  78. ffmpeg.input(
  79. 'pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{out_width}x{out_height}', framerate=args.fps).output(
  80. audio, video_save_path, pix_fmt='yuv420p', vcodec='libx264', loglevel='info',
  81. acodec='copy').overwrite_output().run_async(pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
  82. else:
  83. encoder = (
  84. ffmpeg.input(
  85. 'pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{out_width}x{out_height}',
  86. framerate=args.fps).output(video_save_path, pix_fmt='yuv420p', vcodec='libx264',
  87. loglevel='info').overwrite_output().run_async(
  88. pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
  89. while True:
  90. if is_video:
  91. img_bytes = decoder.stdout.read(width * height * 3) # 3 bytes for one pixel
  92. if not img_bytes:
  93. break
  94. img = np.frombuffer(img_bytes, np.uint8).reshape([height, width, 3])
  95. else:
  96. if idx >= len(paths):
  97. break
  98. img = cv2.imread(paths[idx])
  99. idx += 1
  100. try:
  101. if args.face_enhance:
  102. _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
  103. else:
  104. output, _ = upsampler.enhance(img, outscale=args.outscale)
  105. except RuntimeError as error:
  106. print('Error', error)
  107. print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
  108. else:
  109. output = output.astype(np.uint8).tobytes()
  110. encoder.stdin.write(output)
  111. torch.cuda.synchronize()
  112. if is_video:
  113. decoder.stdin.close()
  114. decoder.wait()
  115. encoder.stdin.close()
  116. encoder.wait()
  117. def inference_frames(args, upsampler, face_enhancer):
  118. is_video, paths = get_frames(args, extract_frames=True)
  119. video_name = os.path.splitext(os.path.basename(args.input))[0]
  120. # for saving restored frames
  121. save_frame_folder = os.path.join(args.output, video_name, 'frames_tmpout')
  122. os.makedirs(save_frame_folder, exist_ok=True)
  123. timer = AvgTimer()
  124. timer.start()
  125. pbar = tqdm(total=len(paths), unit='frame', desc='inference')
  126. # set up prefetch reader
  127. reader = PrefetchReader(paths, num_prefetch_queue=4)
  128. reader.start()
  129. que = queue.Queue()
  130. consumers = [IOConsumer(args, que, f'IO_{i}') for i in range(args.consumer)]
  131. for consumer in consumers:
  132. consumer.start()
  133. for idx, (path, img) in enumerate(zip(paths, reader)):
  134. imgname, extension = os.path.splitext(os.path.basename(path))
  135. if len(img.shape) == 3 and img.shape[2] == 4:
  136. img_mode = 'RGBA'
  137. else:
  138. img_mode = None
  139. try:
  140. if args.face_enhance:
  141. _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
  142. else:
  143. output, _ = upsampler.enhance(img, outscale=args.outscale)
  144. except RuntimeError as error:
  145. print('Error', error)
  146. print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
  147. else:
  148. if args.ext == 'auto':
  149. extension = extension[1:]
  150. else:
  151. extension = args.ext
  152. if img_mode == 'RGBA': # RGBA images should be saved in png format
  153. extension = 'png'
  154. save_path = os.path.join(save_frame_folder, f'{imgname}_out.{extension}')
  155. que.put({'output': output, 'save_path': save_path})
  156. pbar.update(1)
  157. torch.cuda.synchronize()
  158. timer.record()
  159. avg_fps = 1. / (timer.get_avg_time() + 1e-7)
  160. pbar.set_description(f'idx {idx}, fps {avg_fps:.2f}')
  161. for _ in range(args.consumer):
  162. que.put('quit')
  163. for consumer in consumers:
  164. consumer.join()
  165. pbar.close()
  166. # merge frames to video
  167. video_save_path = os.path.join(args.output, f'{video_name}_{args.suffix}.mp4')
  168. os.system(f'ffmpeg -r {args.fps} -i {save_frame_folder}/frame%08d_out.{extension} -i {args.input}'
  169. f' -map 0:v:0 -map 1:a:0 -c:a copy -c:v libx264 -r {args.fps} -pix_fmt yuv420p {video_save_path}')
  170. # delete tmp file
  171. shutil.rmtree(save_frame_folder)
  172. frame_folder = os.path.join('tmp_frames', video_name)
  173. if os.path.isdir(frame_folder):
  174. shutil.rmtree(frame_folder)
  175. def main():
  176. """Inference demo for Real-ESRGAN.
  177. It mainly for restoring anime videos.
  178. """
  179. parser = argparse.ArgumentParser()
  180. parser.add_argument('-i', '--input', type=str, default='inputs', help='Input video, image or folder')
  181. parser.add_argument(
  182. '-n',
  183. '--model_name',
  184. type=str,
  185. default='realesr-animevideov3',
  186. help=('Model names: realesr-animevideov3 | RealESRGAN_x4plus_anime_6B | RealESRGAN_x4plus | RealESRNet_x4plus |'
  187. ' RealESRGAN_x2plus | '
  188. 'Default:realesr-animevideov3'))
  189. parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
  190. parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
  191. parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored video')
  192. parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
  193. parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
  194. parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
  195. parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
  196. parser.add_argument(
  197. '--fp32', action='store_true', help='Use fp32 precision during inference. Default: fp16 (half precision).')
  198. parser.add_argument('--fps', type=float, default=None, help='FPS of the output video')
  199. parser.add_argument('--consumer', type=int, default=4, help='Number of IO consumers')
  200. parser.add_argument('--stream', action='store_true')
  201. parser.add_argument('--ffmpeg_bin', type=str, default='ffmpeg', help='The path to ffmpeg')
  202. parser.add_argument(
  203. '--alpha_upsampler',
  204. type=str,
  205. default='realesrgan',
  206. help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
  207. parser.add_argument(
  208. '--ext',
  209. type=str,
  210. default='auto',
  211. help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
  212. args = parser.parse_args()
  213. args.input = args.input.rstrip('/').rstrip('\\')
  214. # ---------------------- determine models according to model names ---------------------- #
  215. args.model_name = args.model_name.split('.pth')[0]
  216. if args.model_name in ['RealESRGAN_x4plus', 'RealESRNet_x4plus']: # x4 RRDBNet model
  217. model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
  218. netscale = 4
  219. elif args.model_name in ['RealESRGAN_x4plus_anime_6B']: # x4 RRDBNet model with 6 blocks
  220. model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
  221. netscale = 4
  222. elif args.model_name in ['RealESRGAN_x2plus']: # x2 RRDBNet model
  223. model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
  224. netscale = 2
  225. elif args.model_name in ['realesr-animevideov3']: # x4 VGG-style model (XS size)
  226. model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
  227. netscale = 4
  228. # ---------------------- determine model paths ---------------------- #
  229. model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
  230. if not os.path.isfile(model_path):
  231. model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
  232. if not os.path.isfile(model_path):
  233. raise ValueError(f'Model {args.model_name} does not exist.')
  234. # restorer
  235. upsampler = RealESRGANer(
  236. scale=netscale,
  237. model_path=model_path,
  238. model=model,
  239. tile=args.tile,
  240. tile_pad=args.tile_pad,
  241. pre_pad=args.pre_pad,
  242. half=not args.fp32)
  243. if 'anime' in args.model_name and args.face_enhance:
  244. print('face_enhance is not supported in anime models, we turned this option off for you. '
  245. 'if you insist on turning it on, please manually comment the relevant lines of code.')
  246. args.face_enhance = False
  247. if args.face_enhance: # Use GFPGAN for face enhancement
  248. from gfpgan import GFPGANer
  249. face_enhancer = GFPGANer(
  250. model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
  251. upscale=args.outscale,
  252. arch='clean',
  253. channel_multiplier=2,
  254. bg_upsampler=upsampler)
  255. else:
  256. face_enhancer = None
  257. os.makedirs(args.output, exist_ok=True)
  258. if args.stream:
  259. inference_stream(args, upsampler, face_enhancer)
  260. else:
  261. inference_frames(args, upsampler, face_enhancer)
  262. if __name__ == '__main__':
  263. main()