realesrnet_model.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. import numpy as np
  2. import random
  3. import torch
  4. from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
  5. from basicsr.data.transforms import paired_random_crop
  6. from basicsr.models.sr_model import SRModel
  7. from basicsr.utils import DiffJPEG, USMSharp
  8. from basicsr.utils.img_process_util import filter2D
  9. from basicsr.utils.registry import MODEL_REGISTRY
  10. from torch.nn import functional as F
  11. @MODEL_REGISTRY.register()
  12. class RealESRNetModel(SRModel):
  13. """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
  14. It is trained without GAN losses.
  15. It mainly performs:
  16. 1. randomly synthesize LQ images in GPU tensors
  17. 2. optimize the networks with GAN training.
  18. """
  19. def __init__(self, opt):
  20. super(RealESRNetModel, self).__init__(opt)
  21. self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
  22. self.usm_sharpener = USMSharp().cuda() # do usm sharpening
  23. self.queue_size = opt.get('queue_size', 180)
  24. @torch.no_grad()
  25. def _dequeue_and_enqueue(self):
  26. """It is the training pair pool for increasing the diversity in a batch.
  27. Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
  28. batch could not have different resize scaling factors. Therefore, we employ this training pair pool
  29. to increase the degradation diversity in a batch.
  30. """
  31. # initialize
  32. b, c, h, w = self.lq.size()
  33. if not hasattr(self, 'queue_lr'):
  34. assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
  35. self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
  36. _, c, h, w = self.gt.size()
  37. self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
  38. self.queue_ptr = 0
  39. if self.queue_ptr == self.queue_size: # the pool is full
  40. # do dequeue and enqueue
  41. # shuffle
  42. idx = torch.randperm(self.queue_size)
  43. self.queue_lr = self.queue_lr[idx]
  44. self.queue_gt = self.queue_gt[idx]
  45. # get first b samples
  46. lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
  47. gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
  48. # update the queue
  49. self.queue_lr[0:b, :, :, :] = self.lq.clone()
  50. self.queue_gt[0:b, :, :, :] = self.gt.clone()
  51. self.lq = lq_dequeue
  52. self.gt = gt_dequeue
  53. else:
  54. # only do enqueue
  55. self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
  56. self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
  57. self.queue_ptr = self.queue_ptr + b
  58. @torch.no_grad()
  59. def feed_data(self, data):
  60. """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
  61. """
  62. if self.is_train and self.opt.get('high_order_degradation', True):
  63. # training data synthesis
  64. self.gt = data['gt'].to(self.device)
  65. # USM sharpen the GT images
  66. if self.opt['gt_usm'] is True:
  67. self.gt = self.usm_sharpener(self.gt)
  68. self.kernel1 = data['kernel1'].to(self.device)
  69. self.kernel2 = data['kernel2'].to(self.device)
  70. self.sinc_kernel = data['sinc_kernel'].to(self.device)
  71. ori_h, ori_w = self.gt.size()[2:4]
  72. # ----------------------- The first degradation process ----------------------- #
  73. # blur
  74. out = filter2D(self.gt, self.kernel1)
  75. # random resize
  76. updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
  77. if updown_type == 'up':
  78. scale = np.random.uniform(1, self.opt['resize_range'][1])
  79. elif updown_type == 'down':
  80. scale = np.random.uniform(self.opt['resize_range'][0], 1)
  81. else:
  82. scale = 1
  83. mode = random.choice(['area', 'bilinear', 'bicubic'])
  84. out = F.interpolate(out, scale_factor=scale, mode=mode)
  85. # add noise
  86. gray_noise_prob = self.opt['gray_noise_prob']
  87. if np.random.uniform() < self.opt['gaussian_noise_prob']:
  88. out = random_add_gaussian_noise_pt(
  89. out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
  90. else:
  91. out = random_add_poisson_noise_pt(
  92. out,
  93. scale_range=self.opt['poisson_scale_range'],
  94. gray_prob=gray_noise_prob,
  95. clip=True,
  96. rounds=False)
  97. # JPEG compression
  98. jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
  99. out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
  100. out = self.jpeger(out, quality=jpeg_p)
  101. # ----------------------- The second degradation process ----------------------- #
  102. # blur
  103. if np.random.uniform() < self.opt['second_blur_prob']:
  104. out = filter2D(out, self.kernel2)
  105. # random resize
  106. updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
  107. if updown_type == 'up':
  108. scale = np.random.uniform(1, self.opt['resize_range2'][1])
  109. elif updown_type == 'down':
  110. scale = np.random.uniform(self.opt['resize_range2'][0], 1)
  111. else:
  112. scale = 1
  113. mode = random.choice(['area', 'bilinear', 'bicubic'])
  114. out = F.interpolate(
  115. out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
  116. # add noise
  117. gray_noise_prob = self.opt['gray_noise_prob2']
  118. if np.random.uniform() < self.opt['gaussian_noise_prob2']:
  119. out = random_add_gaussian_noise_pt(
  120. out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
  121. else:
  122. out = random_add_poisson_noise_pt(
  123. out,
  124. scale_range=self.opt['poisson_scale_range2'],
  125. gray_prob=gray_noise_prob,
  126. clip=True,
  127. rounds=False)
  128. # JPEG compression + the final sinc filter
  129. # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
  130. # as one operation.
  131. # We consider two orders:
  132. # 1. [resize back + sinc filter] + JPEG compression
  133. # 2. JPEG compression + [resize back + sinc filter]
  134. # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
  135. if np.random.uniform() < 0.5:
  136. # resize back + the final sinc filter
  137. mode = random.choice(['area', 'bilinear', 'bicubic'])
  138. out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
  139. out = filter2D(out, self.sinc_kernel)
  140. # JPEG compression
  141. jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
  142. out = torch.clamp(out, 0, 1)
  143. out = self.jpeger(out, quality=jpeg_p)
  144. else:
  145. # JPEG compression
  146. jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
  147. out = torch.clamp(out, 0, 1)
  148. out = self.jpeger(out, quality=jpeg_p)
  149. # resize back + the final sinc filter
  150. mode = random.choice(['area', 'bilinear', 'bicubic'])
  151. out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
  152. out = filter2D(out, self.sinc_kernel)
  153. # clamp and round
  154. self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
  155. # random crop
  156. gt_size = self.opt['gt_size']
  157. self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])
  158. # training pair pool
  159. self._dequeue_and_enqueue()
  160. self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
  161. else:
  162. # for paired training or validation
  163. self.lq = data['lq'].to(self.device)
  164. if 'gt' in data:
  165. self.gt = data['gt'].to(self.device)
  166. self.gt_usm = self.usm_sharpener(self.gt)
  167. def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
  168. # do not use the synthetic process during validation
  169. self.is_train = False
  170. super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
  171. self.is_train = True