|
@@ -13,35 +13,45 @@ from torch.nn import functional as F
|
|
|
|
|
|
@MODEL_REGISTRY.register()
|
|
|
class RealESRGANModel(SRGANModel):
|
|
|
- """RealESRGAN Model"""
|
|
|
+ """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
|
|
+
|
|
|
+ It mainly performs:
|
|
|
+ 1. randomly synthesize LQ images in GPU tensors
|
|
|
+ 2. optimize the networks with GAN training.
|
|
|
+ """
|
|
|
|
|
|
def __init__(self, opt):
|
|
|
super(RealESRGANModel, self).__init__(opt)
|
|
|
- self.jpeger = DiffJPEG(differentiable=False).cuda()
|
|
|
- self.usm_sharpener = USMSharp().cuda()
|
|
|
+ self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
|
|
|
+ self.usm_sharpener = USMSharp().cuda() # do usm sharpening
|
|
|
self.queue_size = opt.get('queue_size', 180)
|
|
|
|
|
|
@torch.no_grad()
|
|
|
def _dequeue_and_enqueue(self):
|
|
|
- # training pair pool
|
|
|
+ """It is the training pair pool for increasing the diversity in a batch.
|
|
|
+
|
|
|
+ Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
|
|
|
+ batch could not have different resize scaling factors. Therefore, we employ this training pair pool
|
|
|
+ to increase the degradation diversity in a batch.
|
|
|
+ """
|
|
|
# initialize
|
|
|
b, c, h, w = self.lq.size()
|
|
|
if not hasattr(self, 'queue_lr'):
|
|
|
- assert self.queue_size % b == 0, 'queue size should be divisible by batch size'
|
|
|
+ assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
|
|
|
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
|
|
|
_, c, h, w = self.gt.size()
|
|
|
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
|
|
|
self.queue_ptr = 0
|
|
|
- if self.queue_ptr == self.queue_size: # full
|
|
|
+ if self.queue_ptr == self.queue_size: # the pool is full
|
|
|
# do dequeue and enqueue
|
|
|
# shuffle
|
|
|
idx = torch.randperm(self.queue_size)
|
|
|
self.queue_lr = self.queue_lr[idx]
|
|
|
self.queue_gt = self.queue_gt[idx]
|
|
|
- # get
|
|
|
+ # get first b samples
|
|
|
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
|
|
|
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
|
|
|
- # update
|
|
|
+ # update the queue
|
|
|
self.queue_lr[0:b, :, :, :] = self.lq.clone()
|
|
|
self.queue_gt[0:b, :, :, :] = self.gt.clone()
|
|
|
|
|
@@ -55,6 +65,8 @@ class RealESRGANModel(SRGANModel):
|
|
|
|
|
|
@torch.no_grad()
|
|
|
def feed_data(self, data):
|
|
|
+ """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
|
|
|
+ """
|
|
|
if self.is_train and self.opt.get('high_order_degradation', True):
|
|
|
# training data synthesis
|
|
|
self.gt = data['gt'].to(self.device)
|
|
@@ -79,7 +91,7 @@ class RealESRGANModel(SRGANModel):
|
|
|
scale = 1
|
|
|
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
|
|
out = F.interpolate(out, scale_factor=scale, mode=mode)
|
|
|
- # noise
|
|
|
+ # add noise
|
|
|
gray_noise_prob = self.opt['gray_noise_prob']
|
|
|
if np.random.uniform() < self.opt['gaussian_noise_prob']:
|
|
|
out = random_add_gaussian_noise_pt(
|
|
@@ -93,7 +105,7 @@ class RealESRGANModel(SRGANModel):
|
|
|
rounds=False)
|
|
|
# JPEG compression
|
|
|
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
|
|
|
- out = torch.clamp(out, 0, 1)
|
|
|
+ out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
|
|
|
out = self.jpeger(out, quality=jpeg_p)
|
|
|
|
|
|
# ----------------------- The second degradation process ----------------------- #
|
|
@@ -111,7 +123,7 @@ class RealESRGANModel(SRGANModel):
|
|
|
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
|
|
out = F.interpolate(
|
|
|
out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
|
|
|
- # noise
|
|
|
+ # add noise
|
|
|
gray_noise_prob = self.opt['gray_noise_prob2']
|
|
|
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
|
|
|
out = random_add_gaussian_noise_pt(
|
|
@@ -162,7 +174,9 @@ class RealESRGANModel(SRGANModel):
|
|
|
self._dequeue_and_enqueue()
|
|
|
# sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
|
|
|
self.gt_usm = self.usm_sharpener(self.gt)
|
|
|
+ self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
|
|
|
else:
|
|
|
+ # for paired training or validation
|
|
|
self.lq = data['lq'].to(self.device)
|
|
|
if 'gt' in data:
|
|
|
self.gt = data['gt'].to(self.device)
|
|
@@ -175,6 +189,7 @@ class RealESRGANModel(SRGANModel):
|
|
|
self.is_train = True
|
|
|
|
|
|
def optimize_parameters(self, current_iter):
|
|
|
+ # usm sharpening
|
|
|
l1_gt = self.gt_usm
|
|
|
percep_gt = self.gt_usm
|
|
|
gan_gt = self.gt_usm
|