test_model.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import torch
  2. import yaml
  3. from basicsr.archs.rrdbnet_arch import RRDBNet
  4. from basicsr.data.paired_image_dataset import PairedImageDataset
  5. from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
  6. from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
  7. from realesrgan.models.realesrgan_model import RealESRGANModel
  8. from realesrgan.models.realesrnet_model import RealESRNetModel
  9. def test_realesrnet_model():
  10. with open('tests/data/test_realesrnet_model.yml', mode='r') as f:
  11. opt = yaml.load(f, Loader=yaml.FullLoader)
  12. # build model
  13. model = RealESRNetModel(opt)
  14. # test attributes
  15. assert model.__class__.__name__ == 'RealESRNetModel'
  16. assert isinstance(model.net_g, RRDBNet)
  17. assert isinstance(model.cri_pix, L1Loss)
  18. assert isinstance(model.optimizers[0], torch.optim.Adam)
  19. # prepare data
  20. gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
  21. kernel1 = torch.rand((1, 5, 5), dtype=torch.float32)
  22. kernel2 = torch.rand((1, 5, 5), dtype=torch.float32)
  23. sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32)
  24. data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel)
  25. model.feed_data(data)
  26. # check dequeue
  27. model.feed_data(data)
  28. # check data shape
  29. assert model.lq.shape == (1, 3, 8, 8)
  30. assert model.gt.shape == (1, 3, 32, 32)
  31. # change probability to test if-else
  32. model.opt['gaussian_noise_prob'] = 0
  33. model.opt['gray_noise_prob'] = 0
  34. model.opt['second_blur_prob'] = 0
  35. model.opt['gaussian_noise_prob2'] = 0
  36. model.opt['gray_noise_prob2'] = 0
  37. model.feed_data(data)
  38. # check data shape
  39. assert model.lq.shape == (1, 3, 8, 8)
  40. assert model.gt.shape == (1, 3, 32, 32)
  41. # ----------------- test nondist_validation -------------------- #
  42. # construct dataloader
  43. dataset_opt = dict(
  44. name='Demo',
  45. dataroot_gt='tests/data/gt',
  46. dataroot_lq='tests/data/lq',
  47. io_backend=dict(type='disk'),
  48. scale=4,
  49. phase='val')
  50. dataset = PairedImageDataset(dataset_opt)
  51. dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
  52. assert model.is_train is True
  53. model.nondist_validation(dataloader, 1, None, False)
  54. assert model.is_train is True
  55. def test_realesrgan_model():
  56. with open('tests/data/test_realesrgan_model.yml', mode='r') as f:
  57. opt = yaml.load(f, Loader=yaml.FullLoader)
  58. # build model
  59. model = RealESRGANModel(opt)
  60. # test attributes
  61. assert model.__class__.__name__ == 'RealESRGANModel'
  62. assert isinstance(model.net_g, RRDBNet) # generator
  63. assert isinstance(model.net_d, UNetDiscriminatorSN) # discriminator
  64. assert isinstance(model.cri_pix, L1Loss)
  65. assert isinstance(model.cri_perceptual, PerceptualLoss)
  66. assert isinstance(model.cri_gan, GANLoss)
  67. assert isinstance(model.optimizers[0], torch.optim.Adam)
  68. assert isinstance(model.optimizers[1], torch.optim.Adam)
  69. # prepare data
  70. gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
  71. kernel1 = torch.rand((1, 5, 5), dtype=torch.float32)
  72. kernel2 = torch.rand((1, 5, 5), dtype=torch.float32)
  73. sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32)
  74. data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel)
  75. model.feed_data(data)
  76. # check dequeue
  77. model.feed_data(data)
  78. # check data shape
  79. assert model.lq.shape == (1, 3, 8, 8)
  80. assert model.gt.shape == (1, 3, 32, 32)
  81. # change probability to test if-else
  82. model.opt['gaussian_noise_prob'] = 0
  83. model.opt['gray_noise_prob'] = 0
  84. model.opt['second_blur_prob'] = 0
  85. model.opt['gaussian_noise_prob2'] = 0
  86. model.opt['gray_noise_prob2'] = 0
  87. model.feed_data(data)
  88. # check data shape
  89. assert model.lq.shape == (1, 3, 8, 8)
  90. assert model.gt.shape == (1, 3, 32, 32)
  91. # ----------------- test nondist_validation -------------------- #
  92. # construct dataloader
  93. dataset_opt = dict(
  94. name='Demo',
  95. dataroot_gt='tests/data/gt',
  96. dataroot_lq='tests/data/lq',
  97. io_backend=dict(type='disk'),
  98. scale=4,
  99. phase='val')
  100. dataset = PairedImageDataset(dataset_opt)
  101. dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
  102. assert model.is_train is True
  103. model.nondist_validation(dataloader, 1, None, False)
  104. assert model.is_train is True
  105. # ----------------- test optimize_parameters -------------------- #
  106. model.feed_data(data)
  107. model.optimize_parameters(1)
  108. assert model.output.shape == (1, 3, 32, 32)
  109. assert isinstance(model.log_dict, dict)
  110. # check returned keys
  111. expected_keys = ['l_g_pix', 'l_g_percep', 'l_g_gan', 'l_d_real', 'out_d_real', 'l_d_fake', 'out_d_fake']
  112. assert set(expected_keys).issubset(set(model.log_dict.keys()))