test_dataset.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import pytest
  2. import yaml
  3. from realesrgan.data.realesrgan_dataset import RealESRGANDataset
  4. from realesrgan.data.realesrgan_paired_dataset import RealESRGANPairedDataset
  5. def test_realesrgan_dataset():
  6. with open('tests/data/demo_option_realesrgan_dataset.yml', mode='r') as f:
  7. opt = yaml.load(f, Loader=yaml.FullLoader)
  8. dataset = RealESRGANDataset(opt)
  9. assert dataset.io_backend_opt['type'] == 'disk' # io backend
  10. assert len(dataset) == 2 # whether to read correct meta info
  11. assert dataset.kernel_list == [
  12. 'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
  13. ] # correct initialization the degradation configurations
  14. assert dataset.betag_range2 == [0.5, 4]
  15. # test __getitem__
  16. result = dataset.__getitem__(0)
  17. # check returned keys
  18. expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
  19. assert set(expected_keys).issubset(set(result.keys()))
  20. # check shape and contents
  21. assert result['gt'].shape == (3, 400, 400)
  22. assert result['kernel1'].shape == (21, 21)
  23. assert result['kernel2'].shape == (21, 21)
  24. assert result['sinc_kernel'].shape == (21, 21)
  25. assert result['gt_path'] == 'tests/data/gt/baboon.png'
  26. # ------------------ test lmdb backend -------------------- #
  27. opt['dataroot_gt'] = 'tests/data/gt.lmdb'
  28. opt['io_backend']['type'] = 'lmdb'
  29. dataset = RealESRGANDataset(opt)
  30. assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
  31. assert len(dataset.paths) == 2 # whether to read correct meta info
  32. assert dataset.kernel_list == [
  33. 'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
  34. ] # correct initialization the degradation configurations
  35. assert dataset.betag_range2 == [0.5, 4]
  36. # test __getitem__
  37. result = dataset.__getitem__(1)
  38. # check returned keys
  39. expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
  40. assert set(expected_keys).issubset(set(result.keys()))
  41. # check shape and contents
  42. assert result['gt'].shape == (3, 400, 400)
  43. assert result['kernel1'].shape == (21, 21)
  44. assert result['kernel2'].shape == (21, 21)
  45. assert result['sinc_kernel'].shape == (21, 21)
  46. assert result['gt_path'] == 'comic'
  47. # ------------------ test with sinc_prob = 0 -------------------- #
  48. opt['dataroot_gt'] = 'tests/data/gt.lmdb'
  49. opt['io_backend']['type'] = 'lmdb'
  50. opt['sinc_prob'] = 0
  51. opt['sinc_prob2'] = 0
  52. opt['final_sinc_prob'] = 0
  53. dataset = RealESRGANDataset(opt)
  54. result = dataset.__getitem__(0)
  55. # check returned keys
  56. expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
  57. assert set(expected_keys).issubset(set(result.keys()))
  58. # check shape and contents
  59. assert result['gt'].shape == (3, 400, 400)
  60. assert result['kernel1'].shape == (21, 21)
  61. assert result['kernel2'].shape == (21, 21)
  62. assert result['sinc_kernel'].shape == (21, 21)
  63. assert result['gt_path'] == 'baboon'
  64. # ------------------ lmdb backend should have paths ends with lmdb -------------------- #
  65. with pytest.raises(ValueError):
  66. opt['dataroot_gt'] = 'tests/data/gt'
  67. opt['io_backend']['type'] = 'lmdb'
  68. dataset = RealESRGANDataset(opt)
  69. def test_realesrgan_paired_dataset():
  70. with open('tests/data/demo_option_realesrgan_paired_dataset.yml', mode='r') as f:
  71. opt = yaml.load(f, Loader=yaml.FullLoader)
  72. dataset = RealESRGANPairedDataset(opt)
  73. assert dataset.io_backend_opt['type'] == 'disk' # io backend
  74. assert len(dataset) == 2 # whether to read correct meta info
  75. # test __getitem__
  76. result = dataset.__getitem__(0)
  77. # check returned keys
  78. expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
  79. assert set(expected_keys).issubset(set(result.keys()))
  80. # check shape and contents
  81. assert result['gt'].shape == (3, 128, 128)
  82. assert result['lq'].shape == (3, 32, 32)
  83. assert result['gt_path'] == 'tests/data/gt/baboon.png'
  84. assert result['lq_path'] == 'tests/data/lq/baboon.png'
  85. # ------------------ test lmdb backend -------------------- #
  86. opt['dataroot_gt'] = 'tests/data/gt.lmdb'
  87. opt['dataroot_lq'] = 'tests/data/lq.lmdb'
  88. opt['io_backend']['type'] = 'lmdb'
  89. dataset = RealESRGANPairedDataset(opt)
  90. assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
  91. assert len(dataset) == 2 # whether to read correct meta info
  92. # test __getitem__
  93. result = dataset.__getitem__(1)
  94. # check returned keys
  95. expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
  96. assert set(expected_keys).issubset(set(result.keys()))
  97. # check shape and contents
  98. assert result['gt'].shape == (3, 128, 128)
  99. assert result['lq'].shape == (3, 32, 32)
  100. assert result['gt_path'] == 'comic'
  101. assert result['lq_path'] == 'comic'
  102. # ------------------ test paired_paths_from_folder -------------------- #
  103. opt['dataroot_gt'] = 'tests/data/gt'
  104. opt['dataroot_lq'] = 'tests/data/lq'
  105. opt['io_backend'] = dict(type='disk')
  106. opt['meta_info'] = None
  107. dataset = RealESRGANPairedDataset(opt)
  108. assert dataset.io_backend_opt['type'] == 'disk' # io backend
  109. assert len(dataset) == 2 # whether to read correct meta info
  110. # test __getitem__
  111. result = dataset.__getitem__(0)
  112. # check returned keys
  113. expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
  114. assert set(expected_keys).issubset(set(result.keys()))
  115. # check shape and contents
  116. assert result['gt'].shape == (3, 128, 128)
  117. assert result['lq'].shape == (3, 32, 32)
  118. # ------------------ test normalization -------------------- #
  119. dataset.mean = [0.5, 0.5, 0.5]
  120. dataset.std = [0.5, 0.5, 0.5]
  121. # test __getitem__
  122. result = dataset.__getitem__(0)
  123. # check returned keys
  124. expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
  125. assert set(expected_keys).issubset(set(result.keys()))
  126. # check shape and contents
  127. assert result['gt'].shape == (3, 128, 128)
  128. assert result['lq'].shape == (3, 32, 32)