test_utils.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import numpy as np
  2. from basicsr.archs.rrdbnet_arch import RRDBNet
  3. from realesrgan.utils import RealESRGANer
  def test_realesrganer():
      """Exercise RealESRGANer construction and its processing pipeline.

      The test walks through, in order: construction with the default and a
      user-supplied RRDBNet, ``pre_process``, ``process``, ``post_process``,
      ``tile_process`` and ``enhance`` on RGB, 16-bit, grayscale and RGBA
      inputs. Only attribute types, output shapes and the reported image
      mode are asserted; pixel values are never checked.

      The steps mutate ``restorer`` state (``scale``, ``mod_scale``,
      ``tile_size``) and each step relies on the state left by the previous
      one, so the statement order must be preserved.

      Requires the pretrained weights referenced by ``model_path`` to exist
      on disk.
      """
      # initialize with default model (model=None)
      restorer = RealESRGANer(
          scale=4,
          model_path='experiments/pretrained_models/RealESRGAN_x4plus.pth',
          model=None,
          tile=10,
          tile_pad=10,
          pre_pad=2,
          half=False)
      assert isinstance(restorer.model, RRDBNet)
      assert restorer.half is False

      # initialize with user-defined model
      model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
      restorer = RealESRGANer(
          scale=4,
          model_path='experiments/pretrained_models/RealESRGAN_x4plus_anime_6B.pth',
          model=model,
          tile=10,
          tile_pad=10,
          pre_pad=2,
          half=True)
      # test attribute
      assert isinstance(restorer.model, RRDBNet)
      assert restorer.half is True

      # ------------------ test pre_process ---------------- #
      img = np.random.random((12, 12, 3)).astype(np.float32)
      restorer.pre_process(img)
      # 12x12 input with pre_pad=2 is expected to come out as 14x14
      assert restorer.img.shape == (1, 3, 14, 14)
      # with modcrop
      # NOTE(review): presumably scale=1 triggers extra padding up to a
      # multiple of 4 (14 -> 16) -- confirm against RealESRGANer.pre_process
      restorer.scale = 1
      restorer.pre_process(img)
      assert restorer.img.shape == (1, 3, 16, 16)

      # ------------------ test process ---------------- #
      # runs on the 16x16 tensor left by the pre_process call above
      restorer.process()
      assert restorer.output.shape == (1, 3, 64, 64)

      # ------------------ test post_process ---------------- #
      restorer.mod_scale = 4
      output = restorer.post_process()
      assert output.shape == (1, 3, 60, 60)

      # ------------------ test tile_process ---------------- #
      restorer.scale = 4
      img = np.random.random((12, 12, 3)).astype(np.float32)
      restorer.pre_process(img)
      restorer.tile_process()
      assert restorer.output.shape == (1, 3, 64, 64)

      # ------------------ test enhance ---------------- #
      img = np.random.random((12, 12, 3)).astype(np.float32)
      result = restorer.enhance(img, outscale=2)
      assert result[0].shape == (24, 24, 3)
      assert result[1] == 'RGB'

      # ------------------ test enhance with 16-bit image---------------- #
      # np.random.random() yields floats in [0, 1); casting those to uint16
      # truncates every value to 0, so the previous fixture was a constant
      # image of 512. Draw genuine random integers instead, keeping the
      # uint16 dtype and the 512 lower bound of the original fixture.
      img = np.random.randint(512, 65536, size=(4, 4, 3), dtype=np.uint16)
      result = restorer.enhance(img, outscale=2)
      assert result[0].shape == (8, 8, 3)
      assert result[1] == 'RGB'

      # ------------------ test enhance with gray image---------------- #
      img = np.random.random((4, 4)).astype(np.float32)
      result = restorer.enhance(img, outscale=2)
      assert result[0].shape == (8, 8)
      assert result[1] == 'L'

      # ------------------ test enhance with RGBA---------------- #
      img = np.random.random((4, 4, 4)).astype(np.float32)
      result = restorer.enhance(img, outscale=2)
      assert result[0].shape == (8, 8, 4)
      assert result[1] == 'RGBA'

      # ------------------ test enhance with RGBA, alpha_upsampler---------------- #
      # tile_size = 0 presumably disables tiling for this call -- confirm
      restorer.tile_size = 0
      img = np.random.random((4, 4, 4)).astype(np.float32)
      result = restorer.enhance(img, outscale=2, alpha_upsampler=None)
      assert result[0].shape == (8, 8, 4)
      assert result[1] == 'RGBA'