Browse Source

add unittest for dataset and archs

Xintao 3 years ago
parent
commit
1d180efaf3

+ 1 - 1
realesrgan/data/realesrgan_paired_dataset.py

@@ -59,7 +59,7 @@ class RealESRGANPairedDataset(data.Dataset):
             # disk backend with meta_info
             # disk backend with meta_info
             # Each line in the meta_info describes the relative path to an image
             # Each line in the meta_info describes the relative path to an image
             with open(self.opt['meta_info']) as fin:
             with open(self.opt['meta_info']) as fin:
-                paths = [line.strip().split(' ')[0] for line in fin]
+                paths = [line.strip() for line in fin]
             self.paths = []
             self.paths = []
             for path in paths:
             for path in paths:
                 gt_path, lq_path = path.split(', ')
                 gt_path, lq_path = path.split(', ')

+ 7 - 1
setup.cfg

@@ -17,7 +17,7 @@ line_length = 120
 multi_line_output = 0
 multi_line_output = 0
 known_standard_library = pkg_resources,setuptools
 known_standard_library = pkg_resources,setuptools
 known_first_party = realesrgan
 known_first_party = realesrgan
-known_third_party = PIL,basicsr,cv2,numpy,torch,torchvision,tqdm
+known_third_party = PIL,basicsr,cv2,numpy,pytest,torch,torchvision,tqdm,yaml
 no_lines_before = STDLIB,LOCALFOLDER
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
 default_section = THIRDPARTY
 
 
@@ -25,3 +25,9 @@ default_section = THIRDPARTY
 skip = .git,./docs/build
 skip = .git,./docs/build
 count =
 count =
 quiet-level = 3
 quiet-level = 3
+
+[aliases]
+test=pytest
+
+[tool:pytest]
+addopts=tests/

+ 28 - 0
tests/data/demo_option_realesrgan_dataset.yml

@@ -0,0 +1,28 @@
+name: Demo
+type: RealESRGANDataset
+dataroot_gt: tests/data/gt
+meta_info: tests/data/meta_info_gt.txt
+io_backend:
+  type: disk
+
+blur_kernel_size: 21
+kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+sinc_prob: 1
+blur_sigma: [0.2, 3]
+betag_range: [0.5, 4]
+betap_range: [1, 2]
+
+blur_kernel_size2: 21
+kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+sinc_prob2: 1
+blur_sigma2: [0.2, 1.5]
+betag_range2: [0.5, 4]
+betap_range2: [1, 2]
+
+final_sinc_prob: 1
+
+gt_size: 128
+use_hflip: True
+use_rot: False

+ 13 - 0
tests/data/demo_option_realesrgan_paired_dataset.yml

@@ -0,0 +1,13 @@
+name: Demo
+type: RealESRGANPairedDataset
+scale: 4
+dataroot_gt: tests/data
+dataroot_lq: tests/data
+meta_info: tests/data/meta_info_pair.txt
+io_backend:
+  type: disk
+
+phase: train
+gt_size: 128
+use_hflip: True
+use_rot: False

BIN
tests/data/gt.lmdb/data.mdb


BIN
tests/data/gt.lmdb/lock.mdb


+ 2 - 0
tests/data/gt.lmdb/meta_info.txt

@@ -0,0 +1,2 @@
+baboon.png (480,500,3) 1
+comic.png (360,240,3) 1

BIN
tests/data/gt/baboon.png


BIN
tests/data/gt/comic.png


BIN
tests/data/lq.lmdb/data.mdb


BIN
tests/data/lq.lmdb/lock.mdb


+ 2 - 0
tests/data/lq.lmdb/meta_info.txt

@@ -0,0 +1,2 @@
+baboon.png (120,125,3) 1
+comic.png (80,60,3) 1

BIN
tests/data/lq/baboon.png


BIN
tests/data/lq/comic.png


+ 2 - 0
tests/data/meta_info_gt.txt

@@ -0,0 +1,2 @@
+baboon.png
+comic.png

+ 2 - 0
tests/data/meta_info_pair.txt

@@ -0,0 +1,2 @@
+gt/baboon.png, lq/baboon.png
+gt/comic.png, lq/comic.png

+ 151 - 0
tests/test_dataset.py

@@ -0,0 +1,151 @@
+import pytest
+import yaml
+
+from realesrgan.data.realesrgan_dataset import RealESRGANDataset
+from realesrgan.data.realesrgan_paired_dataset import RealESRGANPairedDataset
+
+
+def test_realesrgan_dataset():
+
+    with open('tests/data/demo_option_realesrgan_dataset.yml', mode='r') as f:
+        opt = yaml.load(f, Loader=yaml.FullLoader)
+
+    dataset = RealESRGANDataset(opt)
+    assert dataset.io_backend_opt['type'] == 'disk'  # io backend
+    assert len(dataset) == 2  # whether to read correct meta info
+    assert dataset.kernel_list == [
+        'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
+    ]  # correct initialization the degradation configurations
+    assert dataset.betag_range2 == [0.5, 4]
+
+    # test __getitem__
+    result = dataset.__getitem__(0)
+    # check returned keys
+    expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
+    assert set(expected_keys).issubset(set(result.keys()))
+    # check shape and contents
+    assert result['gt'].shape == (3, 400, 400)
+    assert result['kernel1'].shape == (21, 21)
+    assert result['kernel2'].shape == (21, 21)
+    assert result['sinc_kernel'].shape == (21, 21)
+    assert result['gt_path'] == 'tests/data/gt/baboon.png'
+
+    # ------------------ test lmdb backend -------------------- #
+    opt['dataroot_gt'] = 'tests/data/gt.lmdb'
+    opt['io_backend']['type'] = 'lmdb'
+
+    dataset = RealESRGANDataset(opt)
+    assert dataset.io_backend_opt['type'] == 'lmdb'  # io backend
+    assert len(dataset.paths) == 2  # whether to read correct meta info
+    assert dataset.kernel_list == [
+        'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
+    ]  # correct initialization the degradation configurations
+    assert dataset.betag_range2 == [0.5, 4]
+
+    # test __getitem__
+    result = dataset.__getitem__(1)
+    # check returned keys
+    expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
+    assert set(expected_keys).issubset(set(result.keys()))
+    # check shape and contents
+    assert result['gt'].shape == (3, 400, 400)
+    assert result['kernel1'].shape == (21, 21)
+    assert result['kernel2'].shape == (21, 21)
+    assert result['sinc_kernel'].shape == (21, 21)
+    assert result['gt_path'] == 'comic'
+
+    # ------------------ test with sinc_prob = 0 -------------------- #
+    opt['dataroot_gt'] = 'tests/data/gt.lmdb'
+    opt['io_backend']['type'] = 'lmdb'
+    opt['sinc_prob'] = 0
+    opt['sinc_prob2'] = 0
+    opt['final_sinc_prob'] = 0
+    dataset = RealESRGANDataset(opt)
+    result = dataset.__getitem__(0)
+    # check returned keys
+    expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
+    assert set(expected_keys).issubset(set(result.keys()))
+    # check shape and contents
+    assert result['gt'].shape == (3, 400, 400)
+    assert result['kernel1'].shape == (21, 21)
+    assert result['kernel2'].shape == (21, 21)
+    assert result['sinc_kernel'].shape == (21, 21)
+    assert result['gt_path'] == 'baboon'
+
+    # ------------------ lmdb backend should have paths ends with lmdb -------------------- #
+    with pytest.raises(ValueError):
+        opt['dataroot_gt'] = 'tests/data/gt'
+        opt['io_backend']['type'] = 'lmdb'
+        dataset = RealESRGANDataset(opt)
+
+
+def test_realesrgan_paired_dataset():
+
+    with open('tests/data/demo_option_realesrgan_paired_dataset.yml', mode='r') as f:
+        opt = yaml.load(f, Loader=yaml.FullLoader)
+
+    dataset = RealESRGANPairedDataset(opt)
+    assert dataset.io_backend_opt['type'] == 'disk'  # io backend
+    assert len(dataset) == 2  # whether to read correct meta info
+
+    # test __getitem__
+    result = dataset.__getitem__(0)
+    # check returned keys
+    expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
+    assert set(expected_keys).issubset(set(result.keys()))
+    # check shape and contents
+    assert result['gt'].shape == (3, 128, 128)
+    assert result['lq'].shape == (3, 32, 32)
+    assert result['gt_path'] == 'tests/data/gt/baboon.png'
+    assert result['lq_path'] == 'tests/data/lq/baboon.png'
+
+    # ------------------ test lmdb backend -------------------- #
+    opt['dataroot_gt'] = 'tests/data/gt.lmdb'
+    opt['dataroot_lq'] = 'tests/data/lq.lmdb'
+    opt['io_backend']['type'] = 'lmdb'
+
+    dataset = RealESRGANPairedDataset(opt)
+    assert dataset.io_backend_opt['type'] == 'lmdb'  # io backend
+    assert len(dataset) == 2  # whether to read correct meta info
+
+    # test __getitem__
+    result = dataset.__getitem__(1)
+    # check returned keys
+    expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
+    assert set(expected_keys).issubset(set(result.keys()))
+    # check shape and contents
+    assert result['gt'].shape == (3, 128, 128)
+    assert result['lq'].shape == (3, 32, 32)
+    assert result['gt_path'] == 'comic'
+    assert result['lq_path'] == 'comic'
+
+    # ------------------ test paired_paths_from_folder -------------------- #
+    opt['dataroot_gt'] = 'tests/data/gt'
+    opt['dataroot_lq'] = 'tests/data/lq'
+    opt['io_backend'] = dict(type='disk')
+    opt['meta_info'] = None
+
+    dataset = RealESRGANPairedDataset(opt)
+    assert dataset.io_backend_opt['type'] == 'disk'  # io backend
+    assert len(dataset) == 2  # whether to read correct meta info
+
+    # test __getitem__
+    result = dataset.__getitem__(0)
+    # check returned keys
+    expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
+    assert set(expected_keys).issubset(set(result.keys()))
+    # check shape and contents
+    assert result['gt'].shape == (3, 128, 128)
+    assert result['lq'].shape == (3, 32, 32)
+
+    # ------------------ test normalization -------------------- #
+    dataset.mean = [0.5, 0.5, 0.5]
+    dataset.std = [0.5, 0.5, 0.5]
+    # test __getitem__
+    result = dataset.__getitem__(0)
+    # check returned keys
+    expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
+    assert set(expected_keys).issubset(set(result.keys()))
+    # check shape and contents
+    assert result['gt'].shape == (3, 128, 128)
+    assert result['lq'].shape == (3, 32, 32)

+ 19 - 0
tests/test_discriminator_arch.py

@@ -0,0 +1,19 @@
+import torch
+
+from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
+
+
+def test_unetdiscriminatorsn():
+    """Test arch: UNetDiscriminatorSN."""
+
+    # model init and forward (cpu)
+    net = UNetDiscriminatorSN(num_in_ch=3, num_feat=4, skip_connection=True)
+    img = torch.rand((1, 3, 32, 32), dtype=torch.float32)
+    output = net(img)
+    assert output.shape == (1, 1, 32, 32)
+
+    # model init and forward (gpu)
+    if torch.cuda.is_available():
+        net.cuda()
+        output = net(img.cuda())
+        assert output.shape == (1, 1, 32, 32)