123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
- import torch
- import yaml
- from basicsr.archs.rrdbnet_arch import RRDBNet
- from basicsr.data.paired_image_dataset import PairedImageDataset
- from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
- from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
- from realesrgan.models.realesrgan_model import RealESRGANModel
- from realesrgan.models.realesrnet_model import RealESRNetModel
def _load_opt(path):
    """Load a model option dict from the YAML test fixture at ``path``.

    ``yaml.FullLoader`` is kept from the original tests; the fixtures are
    repository files rather than untrusted input.
    """
    with open(path, mode='r', encoding='utf-8') as f:
        return yaml.load(f, Loader=yaml.FullLoader)


def _make_degradation_data():
    """Return a random 32x32 GT patch plus the three 5x5 kernels passed to ``feed_data``."""
    # Keyword arguments are evaluated left to right, so the random draws happen
    # in the same order as before: gt, kernel1, kernel2, sinc_kernel.
    return dict(
        gt=torch.rand((1, 3, 32, 32), dtype=torch.float32),
        kernel1=torch.rand((1, 5, 5), dtype=torch.float32),
        kernel2=torch.rand((1, 5, 5), dtype=torch.float32),
        sinc_kernel=torch.rand((1, 5, 5), dtype=torch.float32))


def _assert_pair_shapes(model):
    """Check the synthesized training pair: 32x32 GT and 8x8 (1/4-scale) LQ."""
    assert model.lq.shape == (1, 3, 8, 8)
    assert model.gt.shape == (1, 3, 32, 32)


def _check_feed_data(model, data):
    """Exercise ``model.feed_data`` with default and then zeroed degradation probabilities.

    Side effect: sets several probabilities in ``model.opt`` to 0, so every
    later ``feed_data`` call on this model also uses the zeroed settings.
    """
    model.feed_data(data)
    # check dequeue: the second call is meant to exercise the dequeue path
    model.feed_data(data)
    _assert_pair_shapes(model)

    # change probability to test if-else
    for key in ('gaussian_noise_prob', 'gray_noise_prob', 'second_blur_prob', 'gaussian_noise_prob2',
                'gray_noise_prob2'):
        model.opt[key] = 0
    model.feed_data(data)
    _assert_pair_shapes(model)


def _check_nondist_validation(model):
    """Run ``model.nondist_validation`` over the demo paired images.

    ``model.is_train`` must be True both before and after validation.
    The ``tests/data/...`` paths are relative, so they resolve against the
    current working directory (presumably the repository root).
    """
    # construct dataloader
    dataset_opt = dict(
        name='Demo',
        dataroot_gt='tests/data/gt',
        dataroot_lq='tests/data/lq',
        io_backend=dict(type='disk'),
        scale=4,
        phase='val')
    dataset = PairedImageDataset(dataset_opt)
    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    assert model.is_train is True
    # NOTE(review): the positional args presumably map to (current_iter, tb_logger, save_img)
    # in basicsr's nondist_validation signature -- confirm if that API changes.
    model.nondist_validation(dataloader, 1, None, False)
    assert model.is_train is True


def test_realesrnet_model():
    """RealESRNetModel: built components, feed_data shapes and non-distributed validation."""
    opt = _load_opt('tests/data/test_realesrnet_model.yml')

    # build model
    model = RealESRNetModel(opt)
    # test attributes
    assert model.__class__.__name__ == 'RealESRNetModel'
    assert isinstance(model.net_g, RRDBNet)
    assert isinstance(model.cri_pix, L1Loss)
    assert isinstance(model.optimizers[0], torch.optim.Adam)

    _check_feed_data(model, _make_degradation_data())

    # ----------------- test nondist_validation -------------------- #
    _check_nondist_validation(model)


def test_realesrgan_model():
    """RealESRGANModel: built components, feed_data, validation and one optimization step."""
    opt = _load_opt('tests/data/test_realesrgan_model.yml')

    # build model
    model = RealESRGANModel(opt)
    # test attributes
    assert model.__class__.__name__ == 'RealESRGANModel'
    assert isinstance(model.net_g, RRDBNet)  # generator
    assert isinstance(model.net_d, UNetDiscriminatorSN)  # discriminator
    assert isinstance(model.cri_pix, L1Loss)
    assert isinstance(model.cri_perceptual, PerceptualLoss)
    assert isinstance(model.cri_gan, GANLoss)
    assert isinstance(model.optimizers[0], torch.optim.Adam)
    assert isinstance(model.optimizers[1], torch.optim.Adam)

    data = _make_degradation_data()
    _check_feed_data(model, data)

    # ----------------- test nondist_validation -------------------- #
    _check_nondist_validation(model)

    # ----------------- test optimize_parameters -------------------- #
    # model.opt still holds the zeroed probabilities set by _check_feed_data.
    model.feed_data(data)
    model.optimize_parameters(1)
    assert model.output.shape == (1, 3, 32, 32)
    assert isinstance(model.log_dict, dict)
    # check returned keys; on failure, report exactly which ones are absent
    expected_keys = {'l_g_pix', 'l_g_percep', 'l_g_gan', 'l_d_real', 'out_d_real', 'l_d_fake', 'out_d_fake'}
    missing = expected_keys.difference(model.log_dict)
    assert not missing, f'log_dict is missing keys: {sorted(missing)}'
|