# realesrgan_model.py
import random
from collections import OrderedDict

import numpy as np
import torch
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from basicsr.data.transforms import paired_random_crop
from basicsr.models.srgan_model import SRGANModel
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from basicsr.utils.registry import MODEL_REGISTRY
from torch.nn import functional as F
  12. @MODEL_REGISTRY.register()
  13. class RealESRGANModel(SRGANModel):
  14. """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
  15. It mainly performs:
  16. 1. randomly synthesize LQ images in GPU tensors
  17. 2. optimize the networks with GAN training.
  18. """
  19. def __init__(self, opt):
  20. super(RealESRGANModel, self).__init__(opt)
  21. self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
  22. self.usm_sharpener = USMSharp().cuda() # do usm sharpening
  23. self.queue_size = opt.get('queue_size', 180)
  24. @torch.no_grad()
  25. def _dequeue_and_enqueue(self):
  26. """It is the training pair pool for increasing the diversity in a batch.
  27. Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
  28. batch could not have different resize scaling factors. Therefore, we employ this training pair pool
  29. to increase the degradation diversity in a batch.
  30. """
  31. # initialize
  32. b, c, h, w = self.lq.size()
  33. if not hasattr(self, 'queue_lr'):
  34. assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
  35. self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
  36. _, c, h, w = self.gt.size()
  37. self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
  38. self.queue_ptr = 0
  39. if self.queue_ptr == self.queue_size: # the pool is full
  40. # do dequeue and enqueue
  41. # shuffle
  42. idx = torch.randperm(self.queue_size)
  43. self.queue_lr = self.queue_lr[idx]
  44. self.queue_gt = self.queue_gt[idx]
  45. # get first b samples
  46. lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
  47. gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
  48. # update the queue
  49. self.queue_lr[0:b, :, :, :] = self.lq.clone()
  50. self.queue_gt[0:b, :, :, :] = self.gt.clone()
  51. self.lq = lq_dequeue
  52. self.gt = gt_dequeue
  53. else:
  54. # only do enqueue
  55. self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
  56. self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
  57. self.queue_ptr = self.queue_ptr + b
  58. @torch.no_grad()
  59. def feed_data(self, data):
  60. """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
  61. """
  62. if self.is_train and self.opt.get('high_order_degradation', True):
  63. # training data synthesis
  64. self.gt = data['gt'].to(self.device)
  65. self.gt_usm = self.usm_sharpener(self.gt)
  66. self.kernel1 = data['kernel1'].to(self.device)
  67. self.kernel2 = data['kernel2'].to(self.device)
  68. self.sinc_kernel = data['sinc_kernel'].to(self.device)
  69. ori_h, ori_w = self.gt.size()[2:4]
  70. # ----------------------- The first degradation process ----------------------- #
  71. # blur
  72. out = filter2D(self.gt_usm, self.kernel1)
  73. # random resize
  74. updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
  75. if updown_type == 'up':
  76. scale = np.random.uniform(1, self.opt['resize_range'][1])
  77. elif updown_type == 'down':
  78. scale = np.random.uniform(self.opt['resize_range'][0], 1)
  79. else:
  80. scale = 1
  81. mode = random.choice(['area', 'bilinear', 'bicubic'])
  82. out = F.interpolate(out, scale_factor=scale, mode=mode)
  83. # add noise
  84. gray_noise_prob = self.opt['gray_noise_prob']
  85. if np.random.uniform() < self.opt['gaussian_noise_prob']:
  86. out = random_add_gaussian_noise_pt(
  87. out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
  88. else:
  89. out = random_add_poisson_noise_pt(
  90. out,
  91. scale_range=self.opt['poisson_scale_range'],
  92. gray_prob=gray_noise_prob,
  93. clip=True,
  94. rounds=False)
  95. # JPEG compression
  96. jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
  97. out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
  98. out = self.jpeger(out, quality=jpeg_p)
  99. # ----------------------- The second degradation process ----------------------- #
  100. # blur
  101. if np.random.uniform() < self.opt['second_blur_prob']:
  102. out = filter2D(out, self.kernel2)
  103. # random resize
  104. updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
  105. if updown_type == 'up':
  106. scale = np.random.uniform(1, self.opt['resize_range2'][1])
  107. elif updown_type == 'down':
  108. scale = np.random.uniform(self.opt['resize_range2'][0], 1)
  109. else:
  110. scale = 1
  111. mode = random.choice(['area', 'bilinear', 'bicubic'])
  112. out = F.interpolate(
  113. out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
  114. # add noise
  115. gray_noise_prob = self.opt['gray_noise_prob2']
  116. if np.random.uniform() < self.opt['gaussian_noise_prob2']:
  117. out = random_add_gaussian_noise_pt(
  118. out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
  119. else:
  120. out = random_add_poisson_noise_pt(
  121. out,
  122. scale_range=self.opt['poisson_scale_range2'],
  123. gray_prob=gray_noise_prob,
  124. clip=True,
  125. rounds=False)
  126. # JPEG compression + the final sinc filter
  127. # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
  128. # as one operation.
  129. # We consider two orders:
  130. # 1. [resize back + sinc filter] + JPEG compression
  131. # 2. JPEG compression + [resize back + sinc filter]
  132. # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
  133. if np.random.uniform() < 0.5:
  134. # resize back + the final sinc filter
  135. mode = random.choice(['area', 'bilinear', 'bicubic'])
  136. out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
  137. out = filter2D(out, self.sinc_kernel)
  138. # JPEG compression
  139. jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
  140. out = torch.clamp(out, 0, 1)
  141. out = self.jpeger(out, quality=jpeg_p)
  142. else:
  143. # JPEG compression
  144. jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
  145. out = torch.clamp(out, 0, 1)
  146. out = self.jpeger(out, quality=jpeg_p)
  147. # resize back + the final sinc filter
  148. mode = random.choice(['area', 'bilinear', 'bicubic'])
  149. out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
  150. out = filter2D(out, self.sinc_kernel)
  151. # clamp and round
  152. self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
  153. # random crop
  154. gt_size = self.opt['gt_size']
  155. (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
  156. self.opt['scale'])
  157. # training pair pool
  158. self._dequeue_and_enqueue()
  159. # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
  160. self.gt_usm = self.usm_sharpener(self.gt)
  161. self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
  162. else:
  163. # for paired training or validation
  164. self.lq = data['lq'].to(self.device)
  165. if 'gt' in data:
  166. self.gt = data['gt'].to(self.device)
  167. self.gt_usm = self.usm_sharpener(self.gt)
  168. def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
  169. # do not use the synthetic process during validation
  170. self.is_train = False
  171. super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
  172. self.is_train = True
  173. def optimize_parameters(self, current_iter):
  174. # usm sharpening
  175. l1_gt = self.gt_usm
  176. percep_gt = self.gt_usm
  177. gan_gt = self.gt_usm
  178. if self.opt['l1_gt_usm'] is False:
  179. l1_gt = self.gt
  180. if self.opt['percep_gt_usm'] is False:
  181. percep_gt = self.gt
  182. if self.opt['gan_gt_usm'] is False:
  183. gan_gt = self.gt
  184. # optimize net_g
  185. for p in self.net_d.parameters():
  186. p.requires_grad = False
  187. self.optimizer_g.zero_grad()
  188. self.output = self.net_g(self.lq)
  189. l_g_total = 0
  190. loss_dict = OrderedDict()
  191. if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
  192. # pixel loss
  193. if self.cri_pix:
  194. l_g_pix = self.cri_pix(self.output, l1_gt)
  195. l_g_total += l_g_pix
  196. loss_dict['l_g_pix'] = l_g_pix
  197. # perceptual loss
  198. if self.cri_perceptual:
  199. l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
  200. if l_g_percep is not None:
  201. l_g_total += l_g_percep
  202. loss_dict['l_g_percep'] = l_g_percep
  203. if l_g_style is not None:
  204. l_g_total += l_g_style
  205. loss_dict['l_g_style'] = l_g_style
  206. # gan loss
  207. fake_g_pred = self.net_d(self.output)
  208. l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
  209. l_g_total += l_g_gan
  210. loss_dict['l_g_gan'] = l_g_gan
  211. l_g_total.backward()
  212. self.optimizer_g.step()
  213. # optimize net_d
  214. for p in self.net_d.parameters():
  215. p.requires_grad = True
  216. self.optimizer_d.zero_grad()
  217. # real
  218. real_d_pred = self.net_d(gan_gt)
  219. l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
  220. loss_dict['l_d_real'] = l_d_real
  221. loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
  222. l_d_real.backward()
  223. # fake
  224. fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9
  225. l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
  226. loss_dict['l_d_fake'] = l_d_fake
  227. loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
  228. l_d_fake.backward()
  229. self.optimizer_d.step()
  230. if self.ema_decay > 0:
  231. self.model_ema(decay=self.ema_decay)
  232. self.log_dict = self.reduce_loss_dict(loss_dict)