srvgg_arch.py

from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn as nn
from torch.nn import functional as F


@ARCH_REGISTRY.register()
class SRVGGNetCompact(nn.Module):
    """A compact VGG-style network structure for super-resolution.

    It is a compact network that performs upsampling only in the last layer,
    so no convolution is conducted on the HR feature space.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_conv (int): Number of convolution layers in the body network. Default: 16.
        upscale (int): Upsampling factor. Default: 4.
        act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: 'prelu'.
    """
    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
        super(SRVGGNetCompact, self).__init__()
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale
        self.act_type = act_type

        self.body = nn.ModuleList()
        # the first conv
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        # the first activation
        if act_type == 'relu':
            activation = nn.ReLU(inplace=True)
        elif act_type == 'prelu':
            activation = nn.PReLU(num_parameters=num_feat)
        elif act_type == 'leakyrelu':
            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.body.append(activation)
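        # Note: PReLU is given num_parameters=num_feat so that every feature
        # channel learns its own negative slope, while ReLU and LeakyReLU add
        # no learnable parameters. An unrecognized act_type would leave
        # `activation` unbound and raise a NameError on the append above.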
        # the body structure
        for _ in range(num_conv):
            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
            # activation
            if act_type == 'relu':
                activation = nn.ReLU(inplace=True)
            elif act_type == 'prelu':
                activation = nn.PReLU(num_parameters=num_feat)
            elif act_type == 'leakyrelu':
                activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
            self.body.append(activation)
        # the last conv
        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
        # upsample
        self.upsampler = nn.PixelShuffle(upscale)
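        # Note: the last conv packs num_out_ch * upscale**2 channels so that
        # PixelShuffle can rearrange (N, C * r**2, H, W) into (N, C, H * r, W * r);
        # every convolution therefore runs at the cheap low-resolution size.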

    def forward(self, x):
        out = x
        for i in range(0, len(self.body)):
            out = self.body[i](out)

        out = self.upsampler(out)
        # add the nearest upsampled image, so that the network learns the residual
        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
        out += base
        return out
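

if __name__ == '__main__':
    # Minimal usage sketch (editorial addition, not part of the upstream file):
    # instantiate the network with its defaults and check the output shape.
    # Assumes a random NCHW float input; no pretrained weights are loaded.
    import torch

    model = SRVGGNetCompact()
    model.eval()
    x = torch.rand(1, 3, 64, 64)  # dummy 64x64 RGB image
    with torch.no_grad():
        y = model(x)
    # 4x upscale: (1, 3, 64, 64) -> (1, 3, 256, 256)
    print(y.shape)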