Browse Source

fix colorspace bug & support multi-gpu and multi-processing (#312)

* fix colorspace bug of ffmpeg stream, add multi-gpu and multi-processing suport for inference_realesrgan_video.py

* fix code format

Co-authored-by: yanzewu <yanzewu@tencent.com>
wyz 3 years ago
parent
commit
8cb9bd403e
3 changed files with 294 additions and 229 deletions
  1. 11 4
      docs/anime_video_model.md
  2. 281 223
      inference_realesrgan_video.py
  3. 2 2
      realesrgan/utils.py

+ 11 - 4
docs/anime_video_model.md

@@ -35,13 +35,20 @@ The following are some demos (best view in the full screen mode).
 ```bash
 # download model
 wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P realesrgan/weights
-# inference
-python inference_realesrgan_video.py -i inputs/video/onepiece_demo.mp4 -n realesr-animevideov3 -s 2 --suffix outx2 --stream
+# single gpu and single process inference
+CUDA_VISIBLE_DEVICES=0 python inference_realesrgan_video.py -i inputs/video/onepiece_demo.mp4 -n realesr-animevideov3 -s 2 --suffix outx2
+# single gpu and multi process inference (you can use multi-processing to improve GPU utilization)
+CUDA_VISIBLE_DEVICES=0 python inference_realesrgan_video.py -i inputs/video/onepiece_demo.mp4 -n realesr-animevideov3 -s 2 --suffix outx2 --num_process_per_gpu 2
+# multi gpu and multi process inference
+CUDA_VISIBLE_DEVICES=0,1,2,3 python inference_realesrgan_video.py -i inputs/video/onepiece_demo.mp4 -n realesr-animevideov3 -s 2 --suffix outx2 --num_process_per_gpu 2
 ```
 ```console
 Usage:
---stream                 with this option, the enhanced frames are sent directly to a ffmpeg stream,
-                         avoiding storing large (usually tens of GB) intermediate results.        
+--num_process_per_gpu    The total number of process is num_gpu * num_process_per_gpu. The bottleneck of
+                         the program lies on the IO, so the GPUs are usually not fully utilized. To alleviate
+                         this issue, you can use multi-processing by setting this parameter. As long as it
+                         does not exceed the CUDA memory
+--extract_frame_first    If you encounter ffmpeg error when using multi-processing, you can turn this option on.
 ```
 
 ### NCNN Executable File

+ 281 - 223
inference_realesrgan_video.py

@@ -4,159 +4,235 @@ import glob
 import mimetypes
 import numpy as np
 import os
-import queue
 import shutil
+import subprocess
 import torch
 from basicsr.archs.rrdbnet_arch import RRDBNet
-from basicsr.utils.logger import AvgTimer
+from os import path as osp
 from tqdm import tqdm
 
-from realesrgan import IOConsumer, PrefetchReader, RealESRGANer
+from realesrgan import RealESRGANer
 from realesrgan.archs.srvgg_arch import SRVGGNetCompact
 
+try:
+    import ffmpeg
+except ImportError:
+    import pip
+    pip.main(['install', '--user', 'ffmpeg-python'])
+    import ffmpeg
+
+
+def get_video_meta_info(video_path):
+    ret = {}
+    probe = ffmpeg.probe(video_path)
+    video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
+    has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams'])
+    ret['width'] = video_streams[0]['width']
+    ret['height'] = video_streams[0]['height']
+    ret['fps'] = eval(video_streams[0]['avg_frame_rate'])
+    ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None
+    ret['nb_frames'] = int(video_streams[0]['nb_frames'])
+    return ret
+
+
+def get_sub_video(args, num_process, process_idx):
+    if num_process == 1:
+        return args.input
+    meta = get_video_meta_info(args.input)
+    duration = int(meta['nb_frames'] / meta['fps'])
+    part_time = duration // num_process
+    print(f'duration: {duration}, part_time: {part_time}')
+    os.makedirs(osp.join(args.output, f'{args.video_name}_inp_tmp_videos'), exist_ok=True)
+    out_path = osp.join(args.output, f'{args.video_name}_inp_tmp_videos', f'{process_idx:03d}.mp4')
+    cmd = [
+        args.ffmpeg_bin, f'-i {args.input}', '-ss', f'{part_time * process_idx}',
+        f'-to {part_time * (process_idx + 1)}' if process_idx != num_process - 1 else '', '-async 1', out_path, '-y'
+    ]
+    print(' '.join(cmd))
+    subprocess.call(' '.join(cmd), shell=True)
+    return out_path
+
+
+class Reader:
+
+    def __init__(self, args, total_workers=1, worker_idx=0):
+        self.args = args
+        input_type = mimetypes.guess_type(args.input)[0]
+        self.input_type = 'folder' if input_type is None else input_type
+        self.paths = []  # for image&folder type
+        self.audio = None
+        self.input_fps = None
+        if self.input_type.startswith('video'):
+            video_path = get_sub_video(args, total_workers, worker_idx)
+            self.stream_reader = (
+                ffmpeg.input(video_path).output('pipe:', format='rawvideo', pix_fmt='bgr24',
+                                                loglevel='error').run_async(
+                                                    pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
+            meta = get_video_meta_info(video_path)
+            self.width = meta['width']
+            self.height = meta['height']
+            self.input_fps = meta['fps']
+            self.audio = meta['audio']
+            self.nb_frames = meta['nb_frames']
 
-def get_frames(args, extract_frames=False):
-    # input can be a video file / a folder of frames / an image
-    is_video = False
-    if mimetypes.guess_type(args.input)[0].startswith('video'):  # is a video file
-        is_video = True
-        video_name = os.path.splitext(os.path.basename(args.input))[0]
-        if extract_frames:
-            frame_folder = os.path.join('tmp_frames', video_name)
-            os.makedirs(frame_folder, exist_ok=True)
-            # use ffmpeg to extract frames
-            os.system(f'ffmpeg -i {args.input} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0  {frame_folder}/frame%08d.png')
-            # get image path list
-            paths = sorted(glob.glob(os.path.join(frame_folder, '*')))
         else:
-            paths = []
-        # get input video fps
-        if args.fps is None:
-            import ffmpeg
-            probe = ffmpeg.probe(args.input)
-            video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
-            args.fps = eval(video_streams[0]['avg_frame_rate'])
-    elif mimetypes.guess_type(args.input)[0].startswith('image'):  # is an image file
-        paths = [args.input]
-    else:
-        paths = sorted(glob.glob(os.path.join(args.input, '*')))
-        assert len(paths) > 0, 'the input folder is empty'
-
-    if args.fps is None:
-        args.fps = 24
-
-    return is_video, paths
-
-
-def inference_stream(args, upsampler, face_enhancer):
-    try:
-        import ffmpeg
-    except ImportError:
-        import pip
-        pip.main(['install', '--user', 'ffmpeg-python'])
-        import ffmpeg
-
-    is_video, paths = get_frames(args, extract_frames=False)
-    video_name = os.path.splitext(os.path.basename(args.input))[0]
-    video_save_path = os.path.join(args.output, f'{video_name}_{args.suffix}.mp4')
-
-    # decoder
-    if is_video:
-        # get height and width
-        probe = ffmpeg.probe(args.input)
-        video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
-        width = video_streams[0]['width']
-        height = video_streams[0]['height']
-
-        # set up frame decoder
-        decoder = (
-            ffmpeg.input(args.input).output('pipe:', format='rawvideo', pix_fmt='rgb24', loglevel='warning').run_async(
-                pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
+            if self.input_type.startswith('image'):
+                self.paths = [args.input]
+            else:
+                paths = sorted(glob.glob(os.path.join(args.input, '*')))
+                tot_frames = len(paths)
+                num_frame_per_worker = tot_frames // total_workers + (1 if tot_frames % total_workers else 0)
+                self.paths = paths[num_frame_per_worker * worker_idx:num_frame_per_worker * (worker_idx + 1)]
+
+            self.nb_frames = len(self.paths)
+            assert self.nb_frames > 0, 'empty folder'
+            from PIL import Image
+            tmp_img = Image.open(self.paths[0])
+            self.width, self.height = tmp_img.size
+        self.idx = 0
+
+    def get_resolution(self):
+        return self.height, self.width
+
+    def get_fps(self):
+        if self.args.fps is not None:
+            return self.args.fps
+        elif self.input_fps is not None:
+            return self.input_fps
+        return 24
+
+    def get_audio(self):
+        return self.audio
+
+    def __len__(self):
+        return self.nb_frames
+
+    def get_frame_from_stream(self):
+        img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3)  # 3 bytes for one pixel
+        if not img_bytes:
+            return None
+        img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3])
+        return img
+
+    def get_frame_from_list(self):
+        if self.idx >= self.nb_frames:
+            return None
+        img = cv2.imread(self.paths[self.idx])
+        self.idx += 1
+        return img
+
+    def get_frame(self):
+        if self.input_type.startswith('video'):
+            return self.get_frame_from_stream()
+        else:
+            return self.get_frame_from_list()
+
+    def close(self):
+        if self.input_type.startswith('video'):
+            self.stream_reader.stdin.close()
+            self.stream_reader.wait()
+
+
+class Writer:
+
+    def __init__(self, args, audio, height, width, video_save_path, fps):
+        out_width, out_height = int(width * args.outscale), int(height * args.outscale)
+        if out_height > 2160:
+            print('You are generating video that is larger than 4K, which will be very slow due to IO speed.',
+                  'We highly recommend to decrease the outscale(aka, -s).')
+
+        if audio is not None:
+            self.stream_writer = (
+                ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{out_width}x{out_height}',
+                             framerate=fps).output(
+                                 audio,
+                                 video_save_path,
+                                 pix_fmt='yuv420p',
+                                 vcodec='libx264',
+                                 loglevel='error',
+                                 acodec='copy').overwrite_output().run_async(
+                                     pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
+        else:
+            self.stream_writer = (
+                ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{out_width}x{out_height}',
+                             framerate=fps).output(
+                                 video_save_path, pix_fmt='yuv420p', vcodec='libx264',
+                                 loglevel='error').overwrite_output().run_async(
+                                     pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
+
+    def write_frame(self, frame):
+        frame = frame.astype(np.uint8).tobytes()
+        self.stream_writer.stdin.write(frame)
+
+    def close(self):
+        self.stream_writer.stdin.close()
+        self.stream_writer.wait()
+
+
+def inference_video(args, video_save_path, device=None, total_workers=1, worker_idx=0):
+    # ---------------------- determine models according to model names ---------------------- #
+    args.model_name = args.model_name.split('.pth')[0]
+    if args.model_name in ['RealESRGAN_x4plus', 'RealESRNet_x4plus']:  # x4 RRDBNet model
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+        netscale = 4
+    elif args.model_name in ['RealESRGAN_x4plus_anime_6B']:  # x4 RRDBNet model with 6 blocks
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+        netscale = 4
+    elif args.model_name in ['RealESRGAN_x2plus']:  # x2 RRDBNet model
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+        netscale = 2
+    elif args.model_name in ['realesr-animevideov3']:  # x4 VGG-style model (XS size)
+        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
+        netscale = 4
     else:
-        from PIL import Image
-        tmp_img = Image.open(paths[0])
-        width, height = tmp_img.size
-        idx = 0
-
-    out_width, out_height = int(width * args.outscale), int(height * args.outscale)
-    if out_height > 2160:
-        print('You are generating video that is larger than 4K, which will be very slow due to IO speed.',
-              'We highly recommend to decrease the outscale(aka, -s).')
-    # encoder
-    if is_video:
-        audio = ffmpeg.input(args.input).audio
-        encoder = (
-            ffmpeg.input(
-                'pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{out_width}x{out_height}', framerate=args.fps).output(
-                    audio, video_save_path, pix_fmt='yuv420p', vcodec='libx264', loglevel='info',
-                    acodec='copy').overwrite_output().run_async(pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
+        raise NotImplementedError
+
+    # ---------------------- determine model paths ---------------------- #
+    model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
+    if not os.path.isfile(model_path):
+        model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
+    if not os.path.isfile(model_path):
+        raise ValueError(f'Model {args.model_name} does not exist.')
+
+    # restorer
+    upsampler = RealESRGANer(
+        scale=netscale,
+        model_path=model_path,
+        model=model,
+        tile=args.tile,
+        tile_pad=args.tile_pad,
+        pre_pad=args.pre_pad,
+        half=not args.fp32,
+        device=device,
+    )
+
+    if 'anime' in args.model_name and args.face_enhance:
+        print('face_enhance is not supported in anime models, we turned this option off for you. '
+              'if you insist on turning it on, please manually comment the relevant lines of code.')
+        args.face_enhance = False
+
+    if args.face_enhance:  # Use GFPGAN for face enhancement
+        from gfpgan import GFPGANer
+        face_enhancer = GFPGANer(
+            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
+            upscale=args.outscale,
+            arch='clean',
+            channel_multiplier=2,
+            bg_upsampler=upsampler)  # TODO support custom device
     else:
-        encoder = (
-            ffmpeg.input(
-                'pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{out_width}x{out_height}',
-                framerate=args.fps).output(video_save_path, pix_fmt='yuv420p', vcodec='libx264',
-                                           loglevel='info').overwrite_output().run_async(
-                                               pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
+        face_enhancer = None
 
-    while True:
-        if is_video:
-            img_bytes = decoder.stdout.read(width * height * 3)  # 3 bytes for one pixel
-            if not img_bytes:
-                break
-            img = np.frombuffer(img_bytes, np.uint8).reshape([height, width, 3])
-        else:
-            if idx >= len(paths):
-                break
-            img = cv2.imread(paths[idx])
-            idx += 1
+    reader = Reader(args, total_workers, worker_idx)
+    audio = reader.get_audio()
+    height, width = reader.get_resolution()
+    fps = reader.get_fps()
+    writer = Writer(args, audio, height, width, video_save_path, fps)
 
-        try:
-            if args.face_enhance:
-                _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
-            else:
-                output, _ = upsampler.enhance(img, outscale=args.outscale)
-        except RuntimeError as error:
-            print('Error', error)
-            print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
-        else:
-            output = output.astype(np.uint8).tobytes()
-            encoder.stdin.write(output)
-
-        torch.cuda.synchronize()
-
-    if is_video:
-        decoder.stdin.close()
-        decoder.wait()
-    encoder.stdin.close()
-    encoder.wait()
-
-
-def inference_frames(args, upsampler, face_enhancer):
-    is_video, paths = get_frames(args, extract_frames=True)
-    video_name = os.path.splitext(os.path.basename(args.input))[0]
-
-    # for saving restored frames
-    save_frame_folder = os.path.join(args.output, video_name, 'frames_tmpout')
-    os.makedirs(save_frame_folder, exist_ok=True)
-
-    timer = AvgTimer()
-    timer.start()
-    pbar = tqdm(total=len(paths), unit='frame', desc='inference')
-    # set up prefetch reader
-    reader = PrefetchReader(paths, num_prefetch_queue=4)
-    reader.start()
-
-    que = queue.Queue()
-    consumers = [IOConsumer(args, que, f'IO_{i}') for i in range(args.consumer)]
-    for consumer in consumers:
-        consumer.start()
-
-    for idx, (path, img) in enumerate(zip(paths, reader)):
-        imgname, extension = os.path.splitext(os.path.basename(path))
-        if len(img.shape) == 3 and img.shape[2] == 4:
-            img_mode = 'RGBA'
-        else:
-            img_mode = None
+    pbar = tqdm(total=len(reader), unit='frame', desc='inference')
+    while True:
+        img = reader.get_frame()
+        if img is None:
+            break
 
         try:
             if args.face_enhance:
@@ -166,39 +242,61 @@ def inference_frames(args, upsampler, face_enhancer):
         except RuntimeError as error:
             print('Error', error)
             print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
-
         else:
-            if args.ext == 'auto':
-                extension = extension[1:]
-            else:
-                extension = args.ext
-            if img_mode == 'RGBA':  # RGBA images should be saved in png format
-                extension = 'png'
-            save_path = os.path.join(save_frame_folder, f'{imgname}_out.{extension}')
-
-            que.put({'output': output, 'save_path': save_path})
+            writer.write_frame(output)
 
+        torch.cuda.synchronize(device)
         pbar.update(1)
-        torch.cuda.synchronize()
-        timer.record()
-        avg_fps = 1. / (timer.get_avg_time() + 1e-7)
-        pbar.set_description(f'idx {idx}, fps {avg_fps:.2f}')
-
-    for _ in range(args.consumer):
-        que.put('quit')
-    for consumer in consumers:
-        consumer.join()
-    pbar.close()
-
-    # merge frames to video
-    video_save_path = os.path.join(args.output, f'{video_name}_{args.suffix}.mp4')
-    os.system(f'ffmpeg -r {args.fps} -i {save_frame_folder}/frame%08d_out.{extension} -i {args.input}'
-              f' -map 0:v:0 -map 1:a:0 -c:a copy -c:v libx264 -r {args.fps} -pix_fmt yuv420p  {video_save_path}')
-    # delete tmp file
-    shutil.rmtree(save_frame_folder)
-    frame_folder = os.path.join('tmp_frames', video_name)
-    if os.path.isdir(frame_folder):
-        shutil.rmtree(frame_folder)
+
+    reader.close()
+    writer.close()
+
+
+def run(args):
+    args.video_name = osp.splitext(os.path.basename(args.input))[0]
+    video_save_path = osp.join(args.output, f'{args.video_name}_{args.suffix}.mp4')
+
+    if args.extract_frame_first:
+        tmp_frames_folder = osp.join(args.output, f'{args.video_name}_inp_tmp_frames')
+        os.makedirs(tmp_frames_folder, exist_ok=True)
+        os.system(f'ffmpeg -i {args.input} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0  {tmp_frames_folder}/frame%08d.png')
+        args.input = tmp_frames_folder
+
+    num_gpus = torch.cuda.device_count()
+    num_process = num_gpus * args.num_process_per_gpu
+    if num_process == 1:
+        inference_video(args, video_save_path)
+        return
+
+    ctx = torch.multiprocessing.get_context('spawn')
+    pool = ctx.Pool(num_process)
+    os.makedirs(osp.join(args.output, f'{args.video_name}_out_tmp_videos'), exist_ok=True)
+    pbar = tqdm(total=num_process, unit='sub_video', desc='inference')
+    for i in range(num_process):
+        sub_video_save_path = osp.join(args.output, f'{args.video_name}_out_tmp_videos', f'{i:03d}.mp4')
+        pool.apply_async(
+            inference_video,
+            args=(args, sub_video_save_path, torch.device(i % num_gpus), num_process, i),
+            callback=lambda arg: pbar.update(1))
+    pool.close()
+    pool.join()
+
+    # combine sub videos
+    # prepare vidlist.txt
+    with open(f'{args.output}/{args.video_name}_vidlist.txt', 'w') as f:
+        for i in range(num_process):
+            f.write(f'file \'{args.video_name}_out_tmp_videos/{i:03d}.mp4\'\n')
+
+    cmd = [
+        args.ffmpeg_bin, '-f', 'concat', '-safe', '0', '-i', f'{args.output}/{args.video_name}_vidlist.txt', '-c',
+        'copy', f'{video_save_path}'
+    ]
+    print(' '.join(cmd))
+    subprocess.call(cmd)
+    shutil.rmtree(osp.join(args.output, f'{args.video_name}_out_tmp_videos'))
+    if osp.exists(osp.join(args.output, f'{args.video_name}_inp_tmp_videos')):
+        shutil.rmtree(osp.join(args.output, f'{args.video_name}_inp_tmp_videos'))
+    os.remove(f'{args.output}/{args.video_name}_vidlist.txt')
 
 
 def main():
@@ -226,9 +324,9 @@ def main():
     parser.add_argument(
         '--fp32', action='store_true', help='Use fp32 precision during inference. Default: fp16 (half precision).')
     parser.add_argument('--fps', type=float, default=None, help='FPS of the output video')
-    parser.add_argument('--consumer', type=int, default=4, help='Number of IO consumers')
-    parser.add_argument('--stream', action='store_true')
     parser.add_argument('--ffmpeg_bin', type=str, default='ffmpeg', help='The path to ffmpeg')
+    parser.add_argument('--extract_frame_first', action='store_true')
+    parser.add_argument('--num_process_per_gpu', type=int, default=1)
 
     parser.add_argument(
         '--alpha_upsampler',
@@ -243,61 +341,21 @@ def main():
     args = parser.parse_args()
 
     args.input = args.input.rstrip('/').rstrip('\\')
+    os.makedirs(args.output, exist_ok=True)
 
-    # ---------------------- determine models according to model names ---------------------- #
-    args.model_name = args.model_name.split('.pth')[0]
-    if args.model_name in ['RealESRGAN_x4plus', 'RealESRNet_x4plus']:  # x4 RRDBNet model
-        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
-        netscale = 4
-    elif args.model_name in ['RealESRGAN_x4plus_anime_6B']:  # x4 RRDBNet model with 6 blocks
-        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
-        netscale = 4
-    elif args.model_name in ['RealESRGAN_x2plus']:  # x2 RRDBNet model
-        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
-        netscale = 2
-    elif args.model_name in ['realesr-animevideov3']:  # x4 VGG-style model (XS size)
-        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
-        netscale = 4
-
-    # ---------------------- determine model paths ---------------------- #
-    model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
-    if not os.path.isfile(model_path):
-        model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
-    if not os.path.isfile(model_path):
-        raise ValueError(f'Model {args.model_name} does not exist.')
-
-    # restorer
-    upsampler = RealESRGANer(
-        scale=netscale,
-        model_path=model_path,
-        model=model,
-        tile=args.tile,
-        tile_pad=args.tile_pad,
-        pre_pad=args.pre_pad,
-        half=not args.fp32)
-
-    if 'anime' in args.model_name and args.face_enhance:
-        print('face_enhance is not supported in anime models, we turned this option off for you. '
-              'if you insist on turning it on, please manually comment the relevant lines of code.')
-        args.face_enhance = False
-
-    if args.face_enhance:  # Use GFPGAN for face enhancement
-        from gfpgan import GFPGANer
-        face_enhancer = GFPGANer(
-            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
-            upscale=args.outscale,
-            arch='clean',
-            channel_multiplier=2,
-            bg_upsampler=upsampler)
+    if mimetypes.guess_type(args.input)[0] is not None and mimetypes.guess_type(args.input)[0].startswith('video'):
+        is_video = True
     else:
-        face_enhancer = None
+        is_video = False
 
-    os.makedirs(args.output, exist_ok=True)
+    if args.extract_frame_first and not is_video:
+        args.extract_frame_first = False
 
-    if args.stream:
-        inference_stream(args, upsampler, face_enhancer)
-    else:
-        inference_frames(args, upsampler, face_enhancer)
+    run(args)
+
+    if args.extract_frame_first:
+        tmp_frames_folder = osp.join(args.output, f'{args.video_name}_inp_tmp_frames')
+        shutil.rmtree(tmp_frames_folder)
 
 
 if __name__ == '__main__':

+ 2 - 2
realesrgan/utils.py

@@ -26,7 +26,7 @@ class RealESRGANer():
         half (float): Whether to use half precision during inference. Default: False.
     """
 
-    def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False):
+    def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False, device=None):
         self.scale = scale
         self.tile_size = tile
         self.tile_pad = tile_pad
@@ -35,7 +35,7 @@ class RealESRGANer():
         self.half = half
 
         # initialize model
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
         # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
         if model_path.startswith('https://'):
             model_path = load_file_from_url(