inference_realesrgan_video.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. import argparse
  2. import cv2
  3. import glob
  4. import mimetypes
  5. import numpy as np
  6. import os
  7. import shutil
  8. import subprocess
  9. import torch
  10. from basicsr.archs.rrdbnet_arch import RRDBNet
  11. from basicsr.utils.download_util import load_file_from_url
  12. from os import path as osp
  13. from tqdm import tqdm
  14. from realesrgan import RealESRGANer
  15. from realesrgan.archs.srvgg_arch import SRVGGNetCompact
  16. try:
  17. import ffmpeg
  18. except ImportError:
  19. import pip
  20. pip.main(['install', '--user', 'ffmpeg-python'])
  21. import ffmpeg
  22. def get_video_meta_info(video_path):
  23. ret = {}
  24. probe = ffmpeg.probe(video_path)
  25. video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
  26. has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams'])
  27. ret['width'] = video_streams[0]['width']
  28. ret['height'] = video_streams[0]['height']
  29. ret['fps'] = eval(video_streams[0]['avg_frame_rate'])
  30. ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None
  31. ret['nb_frames'] = int(video_streams[0]['nb_frames'])
  32. return ret
  33. def get_sub_video(args, num_process, process_idx):
  34. if num_process == 1:
  35. return args.input
  36. meta = get_video_meta_info(args.input)
  37. duration = int(meta['nb_frames'] / meta['fps'])
  38. part_time = duration // num_process
  39. print(f'duration: {duration}, part_time: {part_time}')
  40. os.makedirs(osp.join(args.output, f'{args.video_name}_inp_tmp_videos'), exist_ok=True)
  41. out_path = osp.join(args.output, f'{args.video_name}_inp_tmp_videos', f'{process_idx:03d}.mp4')
  42. cmd = [
  43. args.ffmpeg_bin, f'-i {args.input}', '-ss', f'{part_time * process_idx}',
  44. f'-to {part_time * (process_idx + 1)}' if process_idx != num_process - 1 else '', '-async 1', out_path, '-y'
  45. ]
  46. print(' '.join(cmd))
  47. subprocess.call(' '.join(cmd), shell=True)
  48. return out_path
  49. class Reader:
  50. def __init__(self, args, total_workers=1, worker_idx=0):
  51. self.args = args
  52. input_type = mimetypes.guess_type(args.input)[0]
  53. self.input_type = 'folder' if input_type is None else input_type
  54. self.paths = [] # for image&folder type
  55. self.audio = None
  56. self.input_fps = None
  57. if self.input_type.startswith('video'):
  58. video_path = get_sub_video(args, total_workers, worker_idx)
  59. self.stream_reader = (
  60. ffmpeg.input(video_path).output('pipe:', format='rawvideo', pix_fmt='bgr24',
  61. loglevel='error').run_async(
  62. pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
  63. meta = get_video_meta_info(video_path)
  64. self.width = meta['width']
  65. self.height = meta['height']
  66. self.input_fps = meta['fps']
  67. self.audio = meta['audio']
  68. self.nb_frames = meta['nb_frames']
  69. else:
  70. if self.input_type.startswith('image'):
  71. self.paths = [args.input]
  72. else:
  73. paths = sorted(glob.glob(os.path.join(args.input, '*')))
  74. tot_frames = len(paths)
  75. num_frame_per_worker = tot_frames // total_workers + (1 if tot_frames % total_workers else 0)
  76. self.paths = paths[num_frame_per_worker * worker_idx:num_frame_per_worker * (worker_idx + 1)]
  77. self.nb_frames = len(self.paths)
  78. assert self.nb_frames > 0, 'empty folder'
  79. from PIL import Image
  80. tmp_img = Image.open(self.paths[0])
  81. self.width, self.height = tmp_img.size
  82. self.idx = 0
  83. def get_resolution(self):
  84. return self.height, self.width
  85. def get_fps(self):
  86. if self.args.fps is not None:
  87. return self.args.fps
  88. elif self.input_fps is not None:
  89. return self.input_fps
  90. return 24
  91. def get_audio(self):
  92. return self.audio
  93. def __len__(self):
  94. return self.nb_frames
  95. def get_frame_from_stream(self):
  96. img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3) # 3 bytes for one pixel
  97. if not img_bytes:
  98. return None
  99. img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3])
  100. return img
  101. def get_frame_from_list(self):
  102. if self.idx >= self.nb_frames:
  103. return None
  104. img = cv2.imread(self.paths[self.idx])
  105. self.idx += 1
  106. return img
  107. def get_frame(self):
  108. if self.input_type.startswith('video'):
  109. return self.get_frame_from_stream()
  110. else:
  111. return self.get_frame_from_list()
  112. def close(self):
  113. if self.input_type.startswith('video'):
  114. self.stream_reader.stdin.close()
  115. self.stream_reader.wait()
  116. class Writer:
  117. def __init__(self, args, audio, height, width, video_save_path, fps):
  118. out_width, out_height = int(width * args.outscale), int(height * args.outscale)
  119. if out_height > 2160:
  120. print('You are generating video that is larger than 4K, which will be very slow due to IO speed.',
  121. 'We highly recommend to decrease the outscale(aka, -s).')
  122. if audio is not None:
  123. self.stream_writer = (
  124. ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{out_width}x{out_height}',
  125. framerate=fps).output(
  126. audio,
  127. video_save_path,
  128. pix_fmt='yuv420p',
  129. vcodec='libx264',
  130. loglevel='error',
  131. acodec='copy').overwrite_output().run_async(
  132. pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
  133. else:
  134. self.stream_writer = (
  135. ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{out_width}x{out_height}',
  136. framerate=fps).output(
  137. video_save_path, pix_fmt='yuv420p', vcodec='libx264',
  138. loglevel='error').overwrite_output().run_async(
  139. pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
  140. def write_frame(self, frame):
  141. frame = frame.astype(np.uint8).tobytes()
  142. self.stream_writer.stdin.write(frame)
  143. def close(self):
  144. self.stream_writer.stdin.close()
  145. self.stream_writer.wait()
  146. def inference_video(args, video_save_path, device=None, total_workers=1, worker_idx=0):
  147. # ---------------------- determine models according to model names ---------------------- #
  148. args.model_name = args.model_name.split('.pth')[0]
  149. if args.model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
  150. model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
  151. netscale = 4
  152. file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
  153. elif args.model_name == 'RealESRNet_x4plus': # x4 RRDBNet model
  154. model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
  155. netscale = 4
  156. file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
  157. elif args.model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
  158. model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
  159. netscale = 4
  160. file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
  161. elif args.model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
  162. model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
  163. netscale = 2
  164. file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
  165. elif args.model_name == 'realesr-animevideov3': # x4 VGG-style model (XS size)
  166. model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
  167. netscale = 4
  168. file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']
  169. elif args.model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
  170. model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
  171. netscale = 4
  172. file_url = [
  173. 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
  174. 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
  175. ]
  176. # ---------------------- determine model paths ---------------------- #
  177. model_path = os.path.join('weights', args.model_name + '.pth')
  178. if not os.path.isfile(model_path):
  179. ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
  180. for url in file_url:
  181. # model_path will be updated
  182. model_path = load_file_from_url(
  183. url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
  184. # use dni to control the denoise strength
  185. dni_weight = None
  186. if args.model_name == 'realesr-general-x4v3' and args.denoise_strength != 1:
  187. wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
  188. model_path = [model_path, wdn_model_path]
  189. dni_weight = [args.denoise_strength, 1 - args.denoise_strength]
  190. # restorer
  191. upsampler = RealESRGANer(
  192. scale=netscale,
  193. model_path=model_path,
  194. dni_weight=dni_weight,
  195. model=model,
  196. tile=args.tile,
  197. tile_pad=args.tile_pad,
  198. pre_pad=args.pre_pad,
  199. half=not args.fp32,
  200. device=device,
  201. )
  202. if 'anime' in args.model_name and args.face_enhance:
  203. print('face_enhance is not supported in anime models, we turned this option off for you. '
  204. 'if you insist on turning it on, please manually comment the relevant lines of code.')
  205. args.face_enhance = False
  206. if args.face_enhance: # Use GFPGAN for face enhancement
  207. from gfpgan import GFPGANer
  208. face_enhancer = GFPGANer(
  209. model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
  210. upscale=args.outscale,
  211. arch='clean',
  212. channel_multiplier=2,
  213. bg_upsampler=upsampler) # TODO support custom device
  214. else:
  215. face_enhancer = None
  216. reader = Reader(args, total_workers, worker_idx)
  217. audio = reader.get_audio()
  218. height, width = reader.get_resolution()
  219. fps = reader.get_fps()
  220. writer = Writer(args, audio, height, width, video_save_path, fps)
  221. pbar = tqdm(total=len(reader), unit='frame', desc='inference')
  222. while True:
  223. img = reader.get_frame()
  224. if img is None:
  225. break
  226. try:
  227. if args.face_enhance:
  228. _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
  229. else:
  230. output, _ = upsampler.enhance(img, outscale=args.outscale)
  231. except RuntimeError as error:
  232. print('Error', error)
  233. print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
  234. else:
  235. writer.write_frame(output)
  236. torch.cuda.synchronize(device)
  237. pbar.update(1)
  238. reader.close()
  239. writer.close()
  240. def run(args):
  241. args.video_name = osp.splitext(os.path.basename(args.input))[0]
  242. video_save_path = osp.join(args.output, f'{args.video_name}_{args.suffix}.mp4')
  243. if args.extract_frame_first:
  244. tmp_frames_folder = osp.join(args.output, f'{args.video_name}_inp_tmp_frames')
  245. os.makedirs(tmp_frames_folder, exist_ok=True)
  246. os.system(f'ffmpeg -i {args.input} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0 {tmp_frames_folder}/frame%08d.png')
  247. args.input = tmp_frames_folder
  248. num_gpus = torch.cuda.device_count()
  249. num_process = num_gpus * args.num_process_per_gpu
  250. if num_process == 1:
  251. inference_video(args, video_save_path)
  252. return
  253. ctx = torch.multiprocessing.get_context('spawn')
  254. pool = ctx.Pool(num_process)
  255. os.makedirs(osp.join(args.output, f'{args.video_name}_out_tmp_videos'), exist_ok=True)
  256. pbar = tqdm(total=num_process, unit='sub_video', desc='inference')
  257. for i in range(num_process):
  258. sub_video_save_path = osp.join(args.output, f'{args.video_name}_out_tmp_videos', f'{i:03d}.mp4')
  259. pool.apply_async(
  260. inference_video,
  261. args=(args, sub_video_save_path, torch.device(i % num_gpus), num_process, i),
  262. callback=lambda arg: pbar.update(1))
  263. pool.close()
  264. pool.join()
  265. # combine sub videos
  266. # prepare vidlist.txt
  267. with open(f'{args.output}/{args.video_name}_vidlist.txt', 'w') as f:
  268. for i in range(num_process):
  269. f.write(f'file \'{args.video_name}_out_tmp_videos/{i:03d}.mp4\'\n')
  270. cmd = [
  271. args.ffmpeg_bin, '-f', 'concat', '-safe', '0', '-i', f'{args.output}/{args.video_name}_vidlist.txt', '-c',
  272. 'copy', f'{video_save_path}'
  273. ]
  274. print(' '.join(cmd))
  275. subprocess.call(cmd)
  276. shutil.rmtree(osp.join(args.output, f'{args.video_name}_out_tmp_videos'))
  277. if osp.exists(osp.join(args.output, f'{args.video_name}_inp_tmp_videos')):
  278. shutil.rmtree(osp.join(args.output, f'{args.video_name}_inp_tmp_videos'))
  279. os.remove(f'{args.output}/{args.video_name}_vidlist.txt')
  280. def main():
  281. """Inference demo for Real-ESRGAN.
  282. It mainly for restoring anime videos.
  283. """
  284. parser = argparse.ArgumentParser()
  285. parser.add_argument('-i', '--input', type=str, default='inputs', help='Input video, image or folder')
  286. parser.add_argument(
  287. '-n',
  288. '--model_name',
  289. type=str,
  290. default='realesr-animevideov3',
  291. help=('Model names: realesr-animevideov3 | RealESRGAN_x4plus_anime_6B | RealESRGAN_x4plus | RealESRNet_x4plus |'
  292. ' RealESRGAN_x2plus | realesr-general-x4v3'
  293. 'Default:realesr-animevideov3'))
  294. parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
  295. parser.add_argument(
  296. '-dn',
  297. '--denoise_strength',
  298. type=float,
  299. default=0.5,
  300. help=('Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability. '
  301. 'Only used for the realesr-general-x4v3 model'))
  302. parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
  303. parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored video')
  304. parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
  305. parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
  306. parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
  307. parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
  308. parser.add_argument(
  309. '--fp32', action='store_true', help='Use fp32 precision during inference. Default: fp16 (half precision).')
  310. parser.add_argument('--fps', type=float, default=None, help='FPS of the output video')
  311. parser.add_argument('--ffmpeg_bin', type=str, default='ffmpeg', help='The path to ffmpeg')
  312. parser.add_argument('--extract_frame_first', action='store_true')
  313. parser.add_argument('--num_process_per_gpu', type=int, default=1)
  314. parser.add_argument(
  315. '--alpha_upsampler',
  316. type=str,
  317. default='realesrgan',
  318. help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
  319. parser.add_argument(
  320. '--ext',
  321. type=str,
  322. default='auto',
  323. help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
  324. args = parser.parse_args()
  325. args.input = args.input.rstrip('/').rstrip('\\')
  326. os.makedirs(args.output, exist_ok=True)
  327. if mimetypes.guess_type(args.input)[0] is not None and mimetypes.guess_type(args.input)[0].startswith('video'):
  328. is_video = True
  329. else:
  330. is_video = False
  331. if is_video and args.input.endswith('.flv'):
  332. mp4_path = args.input.replace('.flv', '.mp4')
  333. os.system(f'ffmpeg -i {args.input} -codec copy {mp4_path}')
  334. args.input = mp4_path
  335. if args.extract_frame_first and not is_video:
  336. args.extract_frame_first = False
  337. run(args)
  338. if args.extract_frame_first:
  339. tmp_frames_folder = osp.join(args.output, f'{args.video_name}_inp_tmp_frames')
  340. shutil.rmtree(tmp_frames_folder)
  341. if __name__ == '__main__':
  342. main()