inference_realesrgan_video.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. import argparse
  2. import cv2
  3. import glob
  4. import mimetypes
  5. import numpy as np
  6. import os
  7. import shutil
  8. import subprocess
  9. import torch
  10. from basicsr.archs.rrdbnet_arch import RRDBNet
  11. from os import path as osp
  12. from tqdm import tqdm
  13. from realesrgan import RealESRGANer
  14. from realesrgan.archs.srvgg_arch import SRVGGNetCompact
  15. try:
  16. import ffmpeg
  17. except ImportError:
  18. import pip
  19. pip.main(['install', '--user', 'ffmpeg-python'])
  20. import ffmpeg
  21. def get_video_meta_info(video_path):
  22. ret = {}
  23. probe = ffmpeg.probe(video_path)
  24. video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
  25. has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams'])
  26. ret['width'] = video_streams[0]['width']
  27. ret['height'] = video_streams[0]['height']
  28. ret['fps'] = eval(video_streams[0]['avg_frame_rate'])
  29. ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None
  30. ret['nb_frames'] = int(video_streams[0]['nb_frames'])
  31. return ret
  32. def get_sub_video(args, num_process, process_idx):
  33. if num_process == 1:
  34. return args.input
  35. meta = get_video_meta_info(args.input)
  36. duration = int(meta['nb_frames'] / meta['fps'])
  37. part_time = duration // num_process
  38. print(f'duration: {duration}, part_time: {part_time}')
  39. os.makedirs(osp.join(args.output, f'{args.video_name}_inp_tmp_videos'), exist_ok=True)
  40. out_path = osp.join(args.output, f'{args.video_name}_inp_tmp_videos', f'{process_idx:03d}.mp4')
  41. cmd = [
  42. args.ffmpeg_bin, f'-i {args.input}', '-ss', f'{part_time * process_idx}',
  43. f'-to {part_time * (process_idx + 1)}' if process_idx != num_process - 1 else '', '-async 1', out_path, '-y'
  44. ]
  45. print(' '.join(cmd))
  46. subprocess.call(' '.join(cmd), shell=True)
  47. return out_path
  48. class Reader:
  49. def __init__(self, args, total_workers=1, worker_idx=0):
  50. self.args = args
  51. input_type = mimetypes.guess_type(args.input)[0]
  52. self.input_type = 'folder' if input_type is None else input_type
  53. self.paths = [] # for image&folder type
  54. self.audio = None
  55. self.input_fps = None
  56. if self.input_type.startswith('video'):
  57. video_path = get_sub_video(args, total_workers, worker_idx)
  58. self.stream_reader = (
  59. ffmpeg.input(video_path).output('pipe:', format='rawvideo', pix_fmt='bgr24',
  60. loglevel='error').run_async(
  61. pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
  62. meta = get_video_meta_info(video_path)
  63. self.width = meta['width']
  64. self.height = meta['height']
  65. self.input_fps = meta['fps']
  66. self.audio = meta['audio']
  67. self.nb_frames = meta['nb_frames']
  68. else:
  69. if self.input_type.startswith('image'):
  70. self.paths = [args.input]
  71. else:
  72. paths = sorted(glob.glob(os.path.join(args.input, '*')))
  73. tot_frames = len(paths)
  74. num_frame_per_worker = tot_frames // total_workers + (1 if tot_frames % total_workers else 0)
  75. self.paths = paths[num_frame_per_worker * worker_idx:num_frame_per_worker * (worker_idx + 1)]
  76. self.nb_frames = len(self.paths)
  77. assert self.nb_frames > 0, 'empty folder'
  78. from PIL import Image
  79. tmp_img = Image.open(self.paths[0])
  80. self.width, self.height = tmp_img.size
  81. self.idx = 0
  82. def get_resolution(self):
  83. return self.height, self.width
  84. def get_fps(self):
  85. if self.args.fps is not None:
  86. return self.args.fps
  87. elif self.input_fps is not None:
  88. return self.input_fps
  89. return 24
  90. def get_audio(self):
  91. return self.audio
  92. def __len__(self):
  93. return self.nb_frames
  94. def get_frame_from_stream(self):
  95. img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3) # 3 bytes for one pixel
  96. if not img_bytes:
  97. return None
  98. img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3])
  99. return img
  100. def get_frame_from_list(self):
  101. if self.idx >= self.nb_frames:
  102. return None
  103. img = cv2.imread(self.paths[self.idx])
  104. self.idx += 1
  105. return img
  106. def get_frame(self):
  107. if self.input_type.startswith('video'):
  108. return self.get_frame_from_stream()
  109. else:
  110. return self.get_frame_from_list()
  111. def close(self):
  112. if self.input_type.startswith('video'):
  113. self.stream_reader.stdin.close()
  114. self.stream_reader.wait()
  115. class Writer:
  116. def __init__(self, args, audio, height, width, video_save_path, fps):
  117. out_width, out_height = int(width * args.outscale), int(height * args.outscale)
  118. if out_height > 2160:
  119. print('You are generating video that is larger than 4K, which will be very slow due to IO speed.',
  120. 'We highly recommend to decrease the outscale(aka, -s).')
  121. if audio is not None:
  122. self.stream_writer = (
  123. ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{out_width}x{out_height}',
  124. framerate=fps).output(
  125. audio,
  126. video_save_path,
  127. pix_fmt='yuv420p',
  128. vcodec='libx264',
  129. loglevel='error',
  130. acodec='copy').overwrite_output().run_async(
  131. pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
  132. else:
  133. self.stream_writer = (
  134. ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{out_width}x{out_height}',
  135. framerate=fps).output(
  136. video_save_path, pix_fmt='yuv420p', vcodec='libx264',
  137. loglevel='error').overwrite_output().run_async(
  138. pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
  139. def write_frame(self, frame):
  140. frame = frame.astype(np.uint8).tobytes()
  141. self.stream_writer.stdin.write(frame)
  142. def close(self):
  143. self.stream_writer.stdin.close()
  144. self.stream_writer.wait()
  145. def inference_video(args, video_save_path, device=None, total_workers=1, worker_idx=0):
  146. # ---------------------- determine models according to model names ---------------------- #
  147. args.model_name = args.model_name.split('.pth')[0]
  148. if args.model_name in ['RealESRGAN_x4plus', 'RealESRNet_x4plus']: # x4 RRDBNet model
  149. model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
  150. netscale = 4
  151. elif args.model_name in ['RealESRGAN_x4plus_anime_6B']: # x4 RRDBNet model with 6 blocks
  152. model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
  153. netscale = 4
  154. elif args.model_name in ['RealESRGAN_x2plus']: # x2 RRDBNet model
  155. model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
  156. netscale = 2
  157. elif args.model_name in ['realesr-animevideov3']: # x4 VGG-style model (XS size)
  158. model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
  159. netscale = 4
  160. else:
  161. raise NotImplementedError
  162. # ---------------------- determine model paths ---------------------- #
  163. model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
  164. if not os.path.isfile(model_path):
  165. model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
  166. if not os.path.isfile(model_path):
  167. raise ValueError(f'Model {args.model_name} does not exist.')
  168. # restorer
  169. upsampler = RealESRGANer(
  170. scale=netscale,
  171. model_path=model_path,
  172. model=model,
  173. tile=args.tile,
  174. tile_pad=args.tile_pad,
  175. pre_pad=args.pre_pad,
  176. half=not args.fp32,
  177. device=device,
  178. )
  179. if 'anime' in args.model_name and args.face_enhance:
  180. print('face_enhance is not supported in anime models, we turned this option off for you. '
  181. 'if you insist on turning it on, please manually comment the relevant lines of code.')
  182. args.face_enhance = False
  183. if args.face_enhance: # Use GFPGAN for face enhancement
  184. from gfpgan import GFPGANer
  185. face_enhancer = GFPGANer(
  186. model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
  187. upscale=args.outscale,
  188. arch='clean',
  189. channel_multiplier=2,
  190. bg_upsampler=upsampler) # TODO support custom device
  191. else:
  192. face_enhancer = None
  193. reader = Reader(args, total_workers, worker_idx)
  194. audio = reader.get_audio()
  195. height, width = reader.get_resolution()
  196. fps = reader.get_fps()
  197. writer = Writer(args, audio, height, width, video_save_path, fps)
  198. pbar = tqdm(total=len(reader), unit='frame', desc='inference')
  199. while True:
  200. img = reader.get_frame()
  201. if img is None:
  202. break
  203. try:
  204. if args.face_enhance:
  205. _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
  206. else:
  207. output, _ = upsampler.enhance(img, outscale=args.outscale)
  208. except RuntimeError as error:
  209. print('Error', error)
  210. print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
  211. else:
  212. writer.write_frame(output)
  213. torch.cuda.synchronize(device)
  214. pbar.update(1)
  215. reader.close()
  216. writer.close()
  217. def run(args):
  218. args.video_name = osp.splitext(os.path.basename(args.input))[0]
  219. video_save_path = osp.join(args.output, f'{args.video_name}_{args.suffix}.mp4')
  220. if args.extract_frame_first:
  221. tmp_frames_folder = osp.join(args.output, f'{args.video_name}_inp_tmp_frames')
  222. os.makedirs(tmp_frames_folder, exist_ok=True)
  223. os.system(f'ffmpeg -i {args.input} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0 {tmp_frames_folder}/frame%08d.png')
  224. args.input = tmp_frames_folder
  225. num_gpus = torch.cuda.device_count()
  226. num_process = num_gpus * args.num_process_per_gpu
  227. if num_process == 1:
  228. inference_video(args, video_save_path)
  229. return
  230. ctx = torch.multiprocessing.get_context('spawn')
  231. pool = ctx.Pool(num_process)
  232. os.makedirs(osp.join(args.output, f'{args.video_name}_out_tmp_videos'), exist_ok=True)
  233. pbar = tqdm(total=num_process, unit='sub_video', desc='inference')
  234. for i in range(num_process):
  235. sub_video_save_path = osp.join(args.output, f'{args.video_name}_out_tmp_videos', f'{i:03d}.mp4')
  236. pool.apply_async(
  237. inference_video,
  238. args=(args, sub_video_save_path, torch.device(i % num_gpus), num_process, i),
  239. callback=lambda arg: pbar.update(1))
  240. pool.close()
  241. pool.join()
  242. # combine sub videos
  243. # prepare vidlist.txt
  244. with open(f'{args.output}/{args.video_name}_vidlist.txt', 'w') as f:
  245. for i in range(num_process):
  246. f.write(f'file \'{args.video_name}_out_tmp_videos/{i:03d}.mp4\'\n')
  247. cmd = [
  248. args.ffmpeg_bin, '-f', 'concat', '-safe', '0', '-i', f'{args.output}/{args.video_name}_vidlist.txt', '-c',
  249. 'copy', f'{video_save_path}'
  250. ]
  251. print(' '.join(cmd))
  252. subprocess.call(cmd)
  253. shutil.rmtree(osp.join(args.output, f'{args.video_name}_out_tmp_videos'))
  254. if osp.exists(osp.join(args.output, f'{args.video_name}_inp_tmp_videos')):
  255. shutil.rmtree(osp.join(args.output, f'{args.video_name}_inp_tmp_videos'))
  256. os.remove(f'{args.output}/{args.video_name}_vidlist.txt')
  257. def main():
  258. """Inference demo for Real-ESRGAN.
  259. It mainly for restoring anime videos.
  260. """
  261. parser = argparse.ArgumentParser()
  262. parser.add_argument('-i', '--input', type=str, default='inputs', help='Input video, image or folder')
  263. parser.add_argument(
  264. '-n',
  265. '--model_name',
  266. type=str,
  267. default='realesr-animevideov3',
  268. help=('Model names: realesr-animevideov3 | RealESRGAN_x4plus_anime_6B | RealESRGAN_x4plus | RealESRNet_x4plus |'
  269. ' RealESRGAN_x2plus | '
  270. 'Default:realesr-animevideov3'))
  271. parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
  272. parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
  273. parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored video')
  274. parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
  275. parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
  276. parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
  277. parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
  278. parser.add_argument(
  279. '--fp32', action='store_true', help='Use fp32 precision during inference. Default: fp16 (half precision).')
  280. parser.add_argument('--fps', type=float, default=None, help='FPS of the output video')
  281. parser.add_argument('--ffmpeg_bin', type=str, default='ffmpeg', help='The path to ffmpeg')
  282. parser.add_argument('--extract_frame_first', action='store_true')
  283. parser.add_argument('--num_process_per_gpu', type=int, default=1)
  284. parser.add_argument(
  285. '--alpha_upsampler',
  286. type=str,
  287. default='realesrgan',
  288. help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
  289. parser.add_argument(
  290. '--ext',
  291. type=str,
  292. default='auto',
  293. help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
  294. args = parser.parse_args()
  295. args.input = args.input.rstrip('/').rstrip('\\')
  296. os.makedirs(args.output, exist_ok=True)
  297. if mimetypes.guess_type(args.input)[0] is not None and mimetypes.guess_type(args.input)[0].startswith('video'):
  298. is_video = True
  299. else:
  300. is_video = False
  301. if is_video and args.input.endswith('.flv'):
  302. mp4_path = args.input.replace('.flv', '.mp4')
  303. os.system(f'ffmpeg -i {args.input} -codec copy {mp4_path}')
  304. args.input = mp4_path
  305. if args.extract_frame_first and not is_video:
  306. args.extract_frame_first = False
  307. run(args)
  308. if args.extract_frame_first:
  309. tmp_frames_folder = osp.join(args.output, f'{args.video_name}_inp_tmp_frames')
  310. shutil.rmtree(tmp_frames_folder)
  311. if __name__ == '__main__':
  312. main()