Browse Source

add unittest for model and utils

Xintao 3 years ago
parent
commit
42110857ef

+ 2 - 25
realesrgan/utils.py

@@ -4,9 +4,8 @@ import numpy as np
 import os
 import torch
 from basicsr.archs.rrdbnet_arch import RRDBNet
-from torch.hub import download_url_to_file, get_dir
+from basicsr.utils.download_util import load_file_from_url
 from torch.nn import functional as F
-from urllib.parse import urlparse
 
 ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 
@@ -42,7 +41,7 @@ class RealESRGANer():
         # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
         if model_path.startswith('https://'):
             model_path = load_file_from_url(
-                url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None)
+                url=model_path, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
         loadnet = torch.load(model_path)
         # prefer to use params_ema
         if 'params_ema' in loadnet:
@@ -231,25 +230,3 @@ class RealESRGANer():
                 ), interpolation=cv2.INTER_LANCZOS4)
 
         return output, img_mode
-
-
-def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
-    """Load file form http url, will download models if necessary.
-
-    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
-    """
-    if model_dir is None:
-        hub_dir = get_dir()
-        model_dir = os.path.join(hub_dir, 'checkpoints')
-
-    os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
-
-    parts = urlparse(url)
-    filename = os.path.basename(parts.path)
-    if file_name is not None:
-        filename = file_name
-    cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
-    if not os.path.exists(cached_file):
-        print(f'Downloading: "{url}" to {cached_file}\n')
-        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
-    return cached_file

+ 0 - 0
tests/data/demo_option_realesrgan_dataset.yml → tests/data/test_realesrgan_dataset.yml


+ 115 - 0
tests/data/test_realesrgan_model.yml

@@ -0,0 +1,115 @@
+scale: 4
+num_gpu: 1
+manual_seed: 0
+is_train: True
+dist: False
+
+# ----------------- options for synthesizing training data ----------------- #
+# USM the ground-truth
+l1_gt_usm: True
+percep_gt_usm: True
+gan_gt_usm: False
+
+# the first degradation process
+resize_prob: [0.2, 0.7, 0.1]  # up, down, keep
+resize_range: [0.15, 1.5]
+gaussian_noise_prob: 1
+noise_range: [1, 30]
+poisson_scale_range: [0.05, 3]
+gray_noise_prob: 1
+jpeg_range: [30, 95]
+
+# the second degradation process
+second_blur_prob: 1
+resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep
+resize_range2: [0.3, 1.2]
+gaussian_noise_prob2: 1
+noise_range2: [1, 25]
+poisson_scale_range2: [0.05, 2.5]
+gray_noise_prob2: 1
+jpeg_range2: [30, 95]
+
+gt_size: 32
+queue_size: 1
+
+# network structures
+network_g:
+  type: RRDBNet
+  num_in_ch: 3
+  num_out_ch: 3
+  num_feat: 4
+  num_block: 1
+  num_grow_ch: 2
+
+network_d:
+  type: UNetDiscriminatorSN
+  num_in_ch: 3
+  num_feat: 2
+  skip_connection: True
+
+# path
+path:
+  pretrain_network_g: ~
+  param_key_g: params_ema
+  strict_load_g: true
+  resume_state: ~
+
+# training settings
+train:
+  ema_decay: 0.999
+  optim_g:
+    type: Adam
+    lr: !!float 1e-4
+    weight_decay: 0
+    betas: [0.9, 0.99]
+  optim_d:
+    type: Adam
+    lr: !!float 1e-4
+    weight_decay: 0
+    betas: [0.9, 0.99]
+
+  scheduler:
+    type: MultiStepLR
+    milestones: [400000]
+    gamma: 0.5
+
+  total_iter: 400000
+  warmup_iter: -1  # no warm up
+
+  # losses
+  pixel_opt:
+    type: L1Loss
+    loss_weight: 1.0
+    reduction: mean
+  # perceptual loss (content and style losses)
+  perceptual_opt:
+    type: PerceptualLoss
+    layer_weights:
+      # before relu
+      'conv1_2': 0.1
+      'conv2_2': 0.1
+      'conv3_4': 1
+      'conv4_4': 1
+      'conv5_4': 1
+    vgg_type: vgg19
+    use_input_norm: true
+    perceptual_weight: !!float 1.0
+    style_weight: 0
+    range_norm: false
+    criterion: l1
+  # gan loss
+  gan_opt:
+    type: GANLoss
+    gan_type: vanilla
+    real_label_val: 1.0
+    fake_label_val: 0.0
+    loss_weight: !!float 1e-1
+
+  net_d_iters: 1
+  net_d_init_iters: 0
+
+
+# validation settings
+val:
+  val_freq: !!float 5e3
+  save_img: False

+ 0 - 0
tests/data/demo_option_realesrgan_paired_dataset.yml → tests/data/test_realesrgan_paired_dataset.yml


+ 75 - 0
tests/data/test_realesrnet_model.yml

@@ -0,0 +1,75 @@
+scale: 4
+num_gpu: 1
+manual_seed: 0
+is_train: True
+dist: False
+
+# ----------------- options for synthesizing training data ----------------- #
+gt_usm: True  # USM the ground-truth
+
+# the first degradation process
+resize_prob: [0.2, 0.7, 0.1]  # up, down, keep
+resize_range: [0.15, 1.5]
+gaussian_noise_prob: 1
+noise_range: [1, 30]
+poisson_scale_range: [0.05, 3]
+gray_noise_prob: 1
+jpeg_range: [30, 95]
+
+# the second degradation process
+second_blur_prob: 1
+resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep
+resize_range2: [0.3, 1.2]
+gaussian_noise_prob2: 1
+noise_range2: [1, 25]
+poisson_scale_range2: [0.05, 2.5]
+gray_noise_prob2: 1
+jpeg_range2: [30, 95]
+
+gt_size: 32
+queue_size: 1
+
+# network structures
+network_g:
+  type: RRDBNet
+  num_in_ch: 3
+  num_out_ch: 3
+  num_feat: 4
+  num_block: 1
+  num_grow_ch: 2
+
+# path
+path:
+  pretrain_network_g: ~
+  param_key_g: params_ema
+  strict_load_g: true
+  resume_state: ~
+
+# training settings
+train:
+  ema_decay: 0.999
+  optim_g:
+    type: Adam
+    lr: !!float 2e-4
+    weight_decay: 0
+    betas: [0.9, 0.99]
+
+  scheduler:
+    type: MultiStepLR
+    milestones: [1000000]
+    gamma: 0.5
+
+  total_iter: 1000000
+  warmup_iter: -1  # no warm up
+
+  # losses
+  pixel_opt:
+    type: L1Loss
+    loss_weight: 1.0
+    reduction: mean
+
+
+# validation settings
+val:
+  val_freq: !!float 5e3
+  save_img: False

+ 2 - 2
tests/test_dataset.py

@@ -7,7 +7,7 @@ from realesrgan.data.realesrgan_paired_dataset import RealESRGANPairedDataset
 
 def test_realesrgan_dataset():
 
-    with open('tests/data/demo_option_realesrgan_dataset.yml', mode='r') as f:
+    with open('tests/data/test_realesrgan_dataset.yml', mode='r') as f:
         opt = yaml.load(f, Loader=yaml.FullLoader)
 
     dataset = RealESRGANDataset(opt)
@@ -81,7 +81,7 @@ def test_realesrgan_dataset():
 
 def test_realesrgan_paired_dataset():
 
-    with open('tests/data/demo_option_realesrgan_paired_dataset.yml', mode='r') as f:
+    with open('tests/data/test_realesrgan_paired_dataset.yml', mode='r') as f:
         opt = yaml.load(f, Loader=yaml.FullLoader)
 
     dataset = RealESRGANPairedDataset(opt)

+ 126 - 0
tests/test_model.py

@@ -0,0 +1,126 @@
+import torch
+import yaml
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from basicsr.data.paired_image_dataset import PairedImageDataset
+from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
+
+from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
+from realesrgan.models.realesrgan_model import RealESRGANModel
+from realesrgan.models.realesrnet_model import RealESRNetModel
+
+
+def test_realesrnet_model():
+    with open('tests/data/test_realesrnet_model.yml', mode='r') as f:
+        opt = yaml.load(f, Loader=yaml.FullLoader)
+
+    # build model
+    model = RealESRNetModel(opt)
+    # test attributes
+    assert model.__class__.__name__ == 'RealESRNetModel'
+    assert isinstance(model.net_g, RRDBNet)
+    assert isinstance(model.cri_pix, L1Loss)
+    assert isinstance(model.optimizers[0], torch.optim.Adam)
+
+    # prepare data
+    gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
+    kernel1 = torch.rand((1, 5, 5), dtype=torch.float32)
+    kernel2 = torch.rand((1, 5, 5), dtype=torch.float32)
+    sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32)
+    data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel)
+    model.feed_data(data)
+    # check dequeue
+    model.feed_data(data)
+    # check data shape
+    assert model.lq.shape == (1, 3, 8, 8)
+    assert model.gt.shape == (1, 3, 32, 32)
+
+    # change probability to test if-else
+    model.opt['gaussian_noise_prob'] = 0
+    model.opt['gray_noise_prob'] = 0
+    model.opt['second_blur_prob'] = 0
+    model.opt['gaussian_noise_prob2'] = 0
+    model.opt['gray_noise_prob2'] = 0
+    model.feed_data(data)
+    # check data shape
+    assert model.lq.shape == (1, 3, 8, 8)
+    assert model.gt.shape == (1, 3, 32, 32)
+
+    # ----------------- test nondist_validation -------------------- #
+    # construct dataloader
+    dataset_opt = dict(
+        name='Demo',
+        dataroot_gt='tests/data/gt',
+        dataroot_lq='tests/data/lq',
+        io_backend=dict(type='disk'),
+        scale=4,
+        phase='val')
+    dataset = PairedImageDataset(dataset_opt)
+    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+    assert model.is_train is True
+    model.nondist_validation(dataloader, 1, None, False)
+    assert model.is_train is True
+
+
+def test_realesrgan_model():
+    with open('tests/data/test_realesrgan_model.yml', mode='r') as f:
+        opt = yaml.load(f, Loader=yaml.FullLoader)
+
+    # build model
+    model = RealESRGANModel(opt)
+    # test attributes
+    assert model.__class__.__name__ == 'RealESRGANModel'
+    assert isinstance(model.net_g, RRDBNet)  # generator
+    assert isinstance(model.net_d, UNetDiscriminatorSN)  # discriminator
+    assert isinstance(model.cri_pix, L1Loss)
+    assert isinstance(model.cri_perceptual, PerceptualLoss)
+    assert isinstance(model.cri_gan, GANLoss)
+    assert isinstance(model.optimizers[0], torch.optim.Adam)
+    assert isinstance(model.optimizers[1], torch.optim.Adam)
+
+    # prepare data
+    gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
+    kernel1 = torch.rand((1, 5, 5), dtype=torch.float32)
+    kernel2 = torch.rand((1, 5, 5), dtype=torch.float32)
+    sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32)
+    data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel)
+    model.feed_data(data)
+    # check dequeue
+    model.feed_data(data)
+    # check data shape
+    assert model.lq.shape == (1, 3, 8, 8)
+    assert model.gt.shape == (1, 3, 32, 32)
+
+    # change probability to test if-else
+    model.opt['gaussian_noise_prob'] = 0
+    model.opt['gray_noise_prob'] = 0
+    model.opt['second_blur_prob'] = 0
+    model.opt['gaussian_noise_prob2'] = 0
+    model.opt['gray_noise_prob2'] = 0
+    model.feed_data(data)
+    # check data shape
+    assert model.lq.shape == (1, 3, 8, 8)
+    assert model.gt.shape == (1, 3, 32, 32)
+
+    # ----------------- test nondist_validation -------------------- #
+    # construct dataloader
+    dataset_opt = dict(
+        name='Demo',
+        dataroot_gt='tests/data/gt',
+        dataroot_lq='tests/data/lq',
+        io_backend=dict(type='disk'),
+        scale=4,
+        phase='val')
+    dataset = PairedImageDataset(dataset_opt)
+    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+    assert model.is_train is True
+    model.nondist_validation(dataloader, 1, None, False)
+    assert model.is_train is True
+
+    # ----------------- test optimize_parameters -------------------- #
+    model.feed_data(data)
+    model.optimize_parameters(1)
+    assert model.output.shape == (1, 3, 32, 32)
+    assert isinstance(model.log_dict, dict)
+    # check returned keys
+    expected_keys = ['l_g_pix', 'l_g_percep', 'l_g_gan', 'l_d_real', 'out_d_real', 'l_d_fake', 'out_d_fake']
+    assert set(expected_keys).issubset(set(model.log_dict.keys()))

+ 87 - 0
tests/test_utils.py

@@ -0,0 +1,87 @@
+import numpy as np
+from basicsr.archs.rrdbnet_arch import RRDBNet
+
+from realesrgan.utils import RealESRGANer
+
+
+def test_realesrganer():
+    # initialize with default model
+    restorer = RealESRGANer(
+        scale=4,
+        model_path='experiments/pretrained_models/RealESRGAN_x4plus.pth',
+        model=None,
+        tile=10,
+        tile_pad=10,
+        pre_pad=2,
+        half=False)
+    assert isinstance(restorer.model, RRDBNet)
+    assert restorer.half is False
+    # initialize with user-defined model
+    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+    restorer = RealESRGANer(
+        scale=4,
+        model_path='experiments/pretrained_models/RealESRGAN_x4plus_anime_6B.pth',
+        model=model,
+        tile=10,
+        tile_pad=10,
+        pre_pad=2,
+        half=True)
+    # test attribute
+    assert isinstance(restorer.model, RRDBNet)
+    assert restorer.half is True
+
+    # ------------------ test pre_process ---------------- #
+    img = np.random.random((12, 12, 3)).astype(np.float32)
+    restorer.pre_process(img)
+    assert restorer.img.shape == (1, 3, 14, 14)
+    # with modcrop
+    restorer.scale = 1
+    restorer.pre_process(img)
+    assert restorer.img.shape == (1, 3, 16, 16)
+
+    # ------------------ test process ---------------- #
+    restorer.process()
+    assert restorer.output.shape == (1, 3, 64, 64)
+
+    # ------------------ test post_process ---------------- #
+    restorer.mod_scale = 4
+    output = restorer.post_process()
+    assert output.shape == (1, 3, 60, 60)
+
+    # ------------------ test tile_process ---------------- #
+    restorer.scale = 4
+    img = np.random.random((12, 12, 3)).astype(np.float32)
+    restorer.pre_process(img)
+    restorer.tile_process()
+    assert restorer.output.shape == (1, 3, 64, 64)
+
+    # ------------------ test enhance ---------------- #
+    img = np.random.random((12, 12, 3)).astype(np.float32)
+    result = restorer.enhance(img, outscale=2)
+    assert result[0].shape == (24, 24, 3)
+    assert result[1] == 'RGB'
+
+    # ------------------ test enhance with 16-bit image---------------- #
+    img = np.random.random((4, 4, 3)).astype(np.uint16) + 512
+    result = restorer.enhance(img, outscale=2)
+    assert result[0].shape == (8, 8, 3)
+    assert result[1] == 'RGB'
+
+    # ------------------ test enhance with gray image---------------- #
+    img = np.random.random((4, 4)).astype(np.float32)
+    result = restorer.enhance(img, outscale=2)
+    assert result[0].shape == (8, 8)
+    assert result[1] == 'L'
+
+    # ------------------ test enhance with RGBA---------------- #
+    img = np.random.random((4, 4, 4)).astype(np.float32)
+    result = restorer.enhance(img, outscale=2)
+    assert result[0].shape == (8, 8, 4)
+    assert result[1] == 'RGBA'
+
+    # ------------------ test enhance with RGBA, alpha_upsampler---------------- #
+    restorer.tile_size = 0
+    img = np.random.random((4, 4, 4)).astype(np.float32)
+    result = restorer.enhance(img, outscale=2, alpha_upsampler=None)
+    assert result[0].shape == (8, 8, 4)
+    assert result[1] == 'RGBA'