# test_discriminator_arch.py (561 B)

  1. import torch
  2. from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
  3. def test_unetdiscriminatorsn():
  4. """Test arch: UNetDiscriminatorSN."""
  5. # model init and forward (cpu)
  6. net = UNetDiscriminatorSN(num_in_ch=3, num_feat=4, skip_connection=True)
  7. img = torch.rand((1, 3, 32, 32), dtype=torch.float32)
  8. output = net(img)
  9. assert output.shape == (1, 1, 32, 32)
  10. # model init and forward (gpu)
  11. if torch.cuda.is_available():
  12. net.cuda()
  13. output = net(img.cuda())
  14. assert output.shape == (1, 1, 32, 32)