Browse Source

add SRVGGNetCompact arch, update inference

Xintao 3 years ago
parent
commit
696e1a6741
7 changed files with 139 additions and 62 deletions
  1. 0 4
      FAQ.md
  2. 10 10
      README.md
  3. 10 11
      README_CN.md
  4. 8 7
      docs/anime_model.md
  5. 41 25
      inference_realesrgan.py
  6. 69 0
      realesrgan/archs/srvgg_arch.py
  7. 1 5
      realesrgan/utils.py

+ 0 - 4
FAQ.md

@@ -1,9 +1,5 @@
 # FAQ
 
-1. **What is the difference of `--netscale` and `outscale`?**
-
-A: TODO.
-
 1. **How to select models?**
 
 A: TODO.

+ 10 - 10
README.md

@@ -166,7 +166,7 @@ wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_
 Inference!
 
 ```bash
-python inference_realesrgan.py --model_path experiments/pretrained_models/RealESRGAN_x4plus.pth --input inputs --face_enhance
+python inference_realesrgan.py -n RealESRGAN_x4plus -i inputs --face_enhance
 ```
 
 Results are in the `results` folder
@@ -184,7 +184,7 @@ Pre-trained models: [RealESRGAN_x4plus_anime_6B](https://github.com/xinntao/Real
 # download model
 wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P experiments/pretrained_models
 # inference
-python inference_realesrgan.py --model_path experiments/pretrained_models/RealESRGAN_x4plus_anime_6B.pth --input inputs
+python inference_realesrgan.py -n RealESRGAN_x4plus_anime_6B -i inputs
 ```
 
 Results are in the `results` folder
@@ -194,23 +194,23 @@ Results are in the `results` folder
 1. You can use X4 model for **arbitrary output size** with the argument `outscale`. The program will further perform cheap resize operation after the Real-ESRGAN output.
 
 ```console
-Usage: python inference_realesrgan.py --model_path experiments/pretrained_models/RealESRGAN_x4plus.pth --input infile --output outfile [options]...
+Usage: python inference_realesrgan.py -n RealESRGAN_x4plus -i infile -o outfile [options]...
 
-A common command: python inference_realesrgan.py --model_path experiments/pretrained_models/RealESRGAN_x4plus.pth --input infile --netscale 4 --outscale 3.5 --half --face_enhance
+A common command: python inference_realesrgan.py -n RealESRGAN_x4plus -i infile --outscale 3.5 --half --face_enhance
 
   -h                   show this help
-  --input              Input image or folder. Default: inputs
-  --output             Output folder. Default: results
-  --model_path         Path to the pre-trained model. Default: experiments/pretrained_models/RealESRGAN_x4plus.pth
-  --netscale           Upsample scale factor of the network. Default: 4
-  --outscale           The final upsampling scale of the image. Default: 4
+  -i --input           Input image or folder. Default: inputs
+  -o --output          Output folder. Default: results
+  -n --model_name      Model name. Default: RealESRGAN_x4plus
+  -s, --outscale       The final upsampling scale of the image. Default: 4
   --suffix             Suffix of the restored image. Default: out
-  --tile               Tile size, 0 for no tile during testing. Default: 0
+  -t, --tile           Tile size, 0 for no tile during testing. Default: 0
   --face_enhance       Whether to use GFPGAN to enhance face. Default: False
   --half               Whether to use half precision during inference. Default: False
   --ext                Image extension. Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto
 ```
 
+
 ## :european_castle: Model Zoo
 
 - [RealESRGAN_x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth): X4 model for general images

+ 10 - 11
README_CN.md

@@ -162,7 +162,7 @@ wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_
 推断!
 
 ```bash
-python inference_realesrgan.py --model_path experiments/pretrained_models/RealESRGAN_x4plus.pth --input inputs --face_enhance
+python inference_realesrgan.py -n RealESRGAN_x4plus -i inputs --face_enhance
 ```
 
 结果在`results`文件夹
@@ -180,28 +180,27 @@ python inference_realesrgan.py --model_path experiments/pretrained_models/RealES
 # 下载模型
 wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P experiments/pretrained_models
 # 推断
-python inference_realesrgan.py --model_path experiments/pretrained_models/RealESRGAN_x4plus_anime_6B.pth --input inputs
+python inference_realesrgan.py -n RealESRGAN_x4plus_anime_6B -i inputs
 ```
 
 结果在`results`文件夹
 
 ### Python 脚本的用法
 
-1. 虽然你用了 X4 模型,但是你可以 **输出任意尺寸比例的图片**,只要实用了 `outscale` 参数. 程序会进一步对模型的输出图像进行缩放。
+1. 虽然你使用了 X4 模型,但是你可以 **输出任意尺寸比例的图片**,只要实用了 `outscale` 参数. 程序会进一步对模型的输出图像进行缩放。
 
 ```console
-Usage: python inference_realesrgan.py --model_path experiments/pretrained_models/RealESRGAN_x4plus.pth --input infile --output outfile [options]...
+Usage: python inference_realesrgan.py -n RealESRGAN_x4plus -i infile -o outfile [options]...
 
-A common command: python inference_realesrgan.py --model_path experiments/pretrained_models/RealESRGAN_x4plus.pth --input infile --netscale 4 --outscale 3.5 --half --face_enhance
+A common command: python inference_realesrgan.py -n RealESRGAN_x4plus -i infile --outscale 3.5 --half --face_enhance
 
   -h                   show this help
-  --input              Input image or folder. Default: inputs
-  --output             Output folder. Default: results
-  --model_path         Path to the pre-trained model. Default: experiments/pretrained_models/RealESRGAN_x4plus.pth
-  --netscale           Upsample scale factor of the network. Default: 4
-  --outscale           The final upsampling scale of the image. Default: 4
+  -i --input           Input image or folder. Default: inputs
+  -o --output          Output folder. Default: results
+  -n --model_name      Model name. Default: RealESRGAN_x4plus
+  -s, --outscale       The final upsampling scale of the image. Default: 4
   --suffix             Suffix of the restored image. Default: out
-  --tile               Tile size, 0 for no tile during testing. Default: 0
+  -t, --tile           Tile size, 0 for no tile during testing. Default: 0
   --face_enhance       Whether to use GFPGAN to enhance face. Default: False
   --half               Whether to use half precision during inference. Default: False
   --ext                Image extension. Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto

+ 8 - 7
docs/anime_model.md

@@ -1,12 +1,13 @@
-# Anime model
+# Anime Model
 
 :white_check_mark: We add [*RealESRGAN_x4plus_anime_6B.pth*](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth), which is optimized for **anime** images with much smaller model size.
 
-- [How to Use](#How-to-Use)
-  - [PyTorch Inference](#PyTorch-Inference)
-  - [ncnn Executable File](#ncnn-Executable-File)
-- [Comparisons with waifu2x](#Comparisons-with-waifu2x)
-- [Comparisons with Sliding Bars](#Comparions-with-Sliding-Bars)
+- [Anime Model](#anime-model)
+  - [How to Use](#how-to-use)
+    - [PyTorch Inference](#pytorch-inference)
+    - [ncnn Executable File](#ncnn-executable-file)
+  - [Comparisons with waifu2x](#comparisons-with-waifu2x)
+  - [Comparisons with Sliding Bars](#comparisons-with-sliding-bars)
 
 <p align="center">
   <img src="https://raw.githubusercontent.com/xinntao/public-figures/master/Real-ESRGAN/cmp_realesrgan_anime_1.png">
@@ -26,7 +27,7 @@ Pre-trained models: [RealESRGAN_x4plus_anime_6B](https://github.com/xinntao/Real
 # download model
 wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P experiments/pretrained_models
 # inference
-python inference_realesrgan.py --model_path experiments/pretrained_models/RealESRGAN_x4plus_anime_6B.pth --input inputs
+python inference_realesrgan.py -n RealESRGAN_x4plus_anime_6B -i inputs
 ```
 
 ### ncnn Executable File

+ 41 - 25
inference_realesrgan.py

@@ -5,28 +5,30 @@ import os
 from basicsr.archs.rrdbnet_arch import RRDBNet
 
 from realesrgan import RealESRGANer
+from realesrgan.archs.srvgg_arch import SRVGGNetCompact
 
 
 def main():
     """Inference demo for Real-ESRGAN.
     """
     parser = argparse.ArgumentParser()
-    parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
+    parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
     parser.add_argument(
-        '--model_path',
+        '-n',
+        '--model_name',
         type=str,
-        default='experiments/pretrained_models/RealESRGAN_x4plus.pth',
-        help='Path to the pre-trained model')
-    parser.add_argument('--output', type=str, default='results', help='Output folder')
-    parser.add_argument('--netscale', type=int, default=4, help='Upsample scale factor of the network')
-    parser.add_argument('--outscale', type=float, default=4, help='The final upsampling scale of the image')
+        default='RealESRGAN_x4plus',
+        help=('Model names: RealESRGAN_x4plus | RealESRNet_x4plus | RealESRGAN_x4plus_anime_6B | RealESRGAN_x2plus'
+              'RealESRGANv2-anime-xsx2 | RealESRGANv2-animevideo-xsx2-nousm | RealESRGANv2-animevideo-xsx2'
+              'RealESRGANv2-anime-xsx4 | RealESRGANv2-animevideo-xsx4-nousm | RealESRGANv2-animevideo-xsx4'))
+    parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
+    parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
     parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
-    parser.add_argument('--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
+    parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
     parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
     parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
     parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
     parser.add_argument('--half', action='store_true', help='Use half precision during inference')
-    parser.add_argument('--block', type=int, default=23, help='num_block in RRDB')
     parser.add_argument(
         '--alpha_upsampler',
         type=str,
@@ -39,16 +41,39 @@ def main():
         help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
     args = parser.parse_args()
 
-    if 'RealESRGAN_x4plus_anime_6B.pth' in args.model_path:
-        args.block = 6
-    elif 'RealESRGAN_x2plus.pth' in args.model_path:
-        args.netscale = 2
+    # determine models according to model names
+    args.model_name = args.model_name.split('.')[0]
+    if args.model_name in ['RealESRGAN_x4plus', 'RealESRNet_x4plus']:  # x4 RRDBNet model
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+        netscale = 4
+    elif args.model_name in ['RealESRGAN_x4plus_anime_6B']:  # x4 RRDBNet model with 6 blocks
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+        netscale = 4
+    elif args.model_name in ['RealESRGAN_x2plus']:  # x2 RRDBNet model
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+        netscale = 2
+    elif args.model_name in [
+            'RealESRGANv2-anime-xsx2', 'RealESRGANv2-animevideo-xsx2-nousm', 'RealESRGANv2-animevideo-xsx2'
+    ]:  # x2 VGG-style model (XS size)
+        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=2, act_type='prelu')
+        netscale = 2
+    elif args.model_name in [
+            'RealESRGANv2-anime-xsx4', 'RealESRGANv2-animevideo-xsx4-nousm', 'RealESRGANv2-animevideo-xsx4'
+    ]:  # x4 VGG-style model (XS size)
+        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
+        netscale = 4
 
-    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=args.block, num_grow_ch=32, scale=args.netscale)
+    # determine model paths
+    model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
+    if not os.path.isfile(model_path):
+        model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
+    if not os.path.isfile(model_path):
+        raise ValueError(f'Model {args.model_name} does not exist.')
 
+    # restorer
     upsampler = RealESRGANer(
-        scale=args.netscale,
-        model_path=args.model_path,
+        scale=netscale,
+        model_path=model_path,
         model=model,
         tile=args.tile,
         tile_pad=args.tile_pad,
@@ -80,15 +105,6 @@ def main():
         else:
             img_mode = None
 
-        # give warnings for too large/small images
-        h, w = img.shape[0:2]
-        if max(h, w) > 1000 and args.netscale == 4:
-            import warnings
-            warnings.warn('The input image is large, try X2 model for better performance.')
-        if max(h, w) < 500 and args.netscale == 2:
-            import warnings
-            warnings.warn('The input image is small, try X4 model for better performance.')
-
         try:
             if args.face_enhance:
                 _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)

+ 69 - 0
realesrgan/archs/srvgg_arch.py

@@ -0,0 +1,69 @@
+from basicsr.utils.registry import ARCH_REGISTRY
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+@ARCH_REGISTRY.register()
+class SRVGGNetCompact(nn.Module):
+    """A compact VGG-style network structure for super-resolution.
+
+    It is a compact network structure, which performs upsampling in the last layer and no convolution is
+    conducted on the HR feature space.
+
+    Args:
+        num_in_ch (int): Channel number of inputs. Default: 3.
+        num_out_ch (int): Channel number of outputs. Default: 3.
+        num_feat (int): Channel number of intermediate features. Default: 64.
+        num_conv (int): Number of convolution layers in the body network. Default: 16.
+        upscale (int): Upsampling factor. Default: 4.
+        act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
+    """
+
+    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
+        super(SRVGGNetCompact, self).__init__()
+        self.num_in_ch = num_in_ch
+        self.num_out_ch = num_out_ch
+        self.num_feat = num_feat
+        self.num_conv = num_conv
+        self.upscale = upscale
+        self.act_type = act_type
+
+        self.body = nn.ModuleList()
+        # the first conv
+        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
+        # the first activation
+        if act_type == 'relu':
+            activation = nn.ReLU(inplace=True)
+        elif act_type == 'prelu':
+            activation = nn.PReLU(num_parameters=num_feat)
+        elif act_type == 'leakyrelu':
+            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.body.append(activation)
+
+        # the body structure
+        for _ in range(num_conv):
+            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
+            # activation
+            if act_type == 'relu':
+                activation = nn.ReLU(inplace=True)
+            elif act_type == 'prelu':
+                activation = nn.PReLU(num_parameters=num_feat)
+            elif act_type == 'leakyrelu':
+                activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+            self.body.append(activation)
+
+        # the last conv
+        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
+        # upsample
+        self.upsampler = nn.PixelShuffle(upscale)
+
+    def forward(self, x):
+        out = x
+        for i in range(0, len(self.body)):
+            out = self.body[i](out)
+
+        out = self.upsampler(out)
+        # add the nearest upsampled image, so that the network learns the residual
+        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
+        out += base
+        return out

+ 1 - 5
realesrgan/utils.py

@@ -3,7 +3,6 @@ import math
 import numpy as np
 import os
 import torch
-from basicsr.archs.rrdbnet_arch import RRDBNet
 from basicsr.utils.download_util import load_file_from_url
 from torch.nn import functional as F
 
@@ -16,7 +15,7 @@ class RealESRGANer():
     Args:
         scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
         model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
-        model (nn.Module): The defined network. If None, the model will be constructed here. Default: None.
+        model (nn.Module): The defined network. Default: None.
         tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
             input images into tiles, and then process each of them. Finally, they will be merged into one image.
             0 denotes for do not use tile. Default: 0.
@@ -35,9 +34,6 @@ class RealESRGANer():
 
         # initialize model
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        if model is None:
-            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
-
         # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
         if model_path.startswith('https://'):
             model_path = load_file_from_url(