realesrgan_dataset.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. import cv2
  2. import math
  3. import numpy as np
  4. import os
  5. import os.path as osp
  6. import random
  7. import time
  8. import torch
  9. from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
  10. from basicsr.data.transforms import augment
  11. from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
  12. from basicsr.utils.registry import DATASET_REGISTRY
  13. from torch.utils import data as data
  14. @DATASET_REGISTRY.register()
  15. class RealESRGANDataset(data.Dataset):
  16. """Dataset used for Real-ESRGAN model:
  17. Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
  18. It loads gt (Ground-Truth) images, and augments them.
  19. It also generates blur kernels and sinc kernels for generating low-quality images.
  20. Note that the low-quality images are processed in tensors on GPUS for faster processing.
  21. Args:
  22. opt (dict): Config for train datasets. It contains the following keys:
  23. dataroot_gt (str): Data root path for gt.
  24. meta_info (str): Path for meta information file.
  25. io_backend (dict): IO backend type and other kwarg.
  26. use_hflip (bool): Use horizontal flips.
  27. use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
  28. Please see more options in the codes.
  29. """
  30. def __init__(self, opt):
  31. super(RealESRGANDataset, self).__init__()
  32. self.opt = opt
  33. self.file_client = None
  34. self.io_backend_opt = opt['io_backend']
  35. self.gt_folder = opt['dataroot_gt']
  36. # file client (lmdb io backend)
  37. if self.io_backend_opt['type'] == 'lmdb':
  38. self.io_backend_opt['db_paths'] = [self.gt_folder]
  39. self.io_backend_opt['client_keys'] = ['gt']
  40. if not self.gt_folder.endswith('.lmdb'):
  41. raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
  42. with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
  43. self.paths = [line.split('.')[0] for line in fin]
  44. else:
  45. # disk backend with meta_info
  46. # Each line in the meta_info describes the relative path to an image
  47. with open(self.opt['meta_info']) as fin:
  48. paths = [line.strip().split(' ')[0] for line in fin]
  49. self.paths = [os.path.join(self.gt_folder, v) for v in paths]
  50. # blur settings for the first degradation
  51. self.blur_kernel_size = opt['blur_kernel_size']
  52. self.kernel_list = opt['kernel_list']
  53. self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
  54. self.blur_sigma = opt['blur_sigma']
  55. self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
  56. self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
  57. self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
  58. # blur settings for the second degradation
  59. self.blur_kernel_size2 = opt['blur_kernel_size2']
  60. self.kernel_list2 = opt['kernel_list2']
  61. self.kernel_prob2 = opt['kernel_prob2']
  62. self.blur_sigma2 = opt['blur_sigma2']
  63. self.betag_range2 = opt['betag_range2']
  64. self.betap_range2 = opt['betap_range2']
  65. self.sinc_prob2 = opt['sinc_prob2']
  66. # a final sinc filter
  67. self.final_sinc_prob = opt['final_sinc_prob']
  68. self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
  69. # TODO: kernel range is now hard-coded, should be in the configure file
  70. self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
  71. self.pulse_tensor[10, 10] = 1
  72. def __getitem__(self, index):
  73. if self.file_client is None:
  74. self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
  75. # -------------------------------- Load gt images -------------------------------- #
  76. # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
  77. gt_path = self.paths[index]
  78. # avoid errors caused by high latency in reading files
  79. retry = 3
  80. while retry > 0:
  81. try:
  82. img_bytes = self.file_client.get(gt_path, 'gt')
  83. except (IOError, OSError) as e:
  84. logger = get_root_logger()
  85. logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
  86. # change another file to read
  87. index = random.randint(0, self.__len__())
  88. gt_path = self.paths[index]
  89. time.sleep(1) # sleep 1s for occasional server congestion
  90. else:
  91. break
  92. finally:
  93. retry -= 1
  94. img_gt = imfrombytes(img_bytes, float32=True)
  95. # -------------------- Do augmentation for training: flip, rotation -------------------- #
  96. img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
  97. # crop or pad to 400
  98. # TODO: 400 is hard-coded. You may change it accordingly
  99. h, w = img_gt.shape[0:2]
  100. crop_pad_size = 400
  101. # pad
  102. if h < crop_pad_size or w < crop_pad_size:
  103. pad_h = max(0, crop_pad_size - h)
  104. pad_w = max(0, crop_pad_size - w)
  105. img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
  106. # crop
  107. if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
  108. h, w = img_gt.shape[0:2]
  109. # randomly choose top and left coordinates
  110. top = random.randint(0, h - crop_pad_size)
  111. left = random.randint(0, w - crop_pad_size)
  112. img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]
  113. # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
  114. kernel_size = random.choice(self.kernel_range)
  115. if np.random.uniform() < self.opt['sinc_prob']:
  116. # this sinc filter setting is for kernels ranging from [7, 21]
  117. if kernel_size < 13:
  118. omega_c = np.random.uniform(np.pi / 3, np.pi)
  119. else:
  120. omega_c = np.random.uniform(np.pi / 5, np.pi)
  121. kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
  122. else:
  123. kernel = random_mixed_kernels(
  124. self.kernel_list,
  125. self.kernel_prob,
  126. kernel_size,
  127. self.blur_sigma,
  128. self.blur_sigma, [-math.pi, math.pi],
  129. self.betag_range,
  130. self.betap_range,
  131. noise_range=None)
  132. # pad kernel
  133. pad_size = (21 - kernel_size) // 2
  134. kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
  135. # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
  136. kernel_size = random.choice(self.kernel_range)
  137. if np.random.uniform() < self.opt['sinc_prob2']:
  138. if kernel_size < 13:
  139. omega_c = np.random.uniform(np.pi / 3, np.pi)
  140. else:
  141. omega_c = np.random.uniform(np.pi / 5, np.pi)
  142. kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
  143. else:
  144. kernel2 = random_mixed_kernels(
  145. self.kernel_list2,
  146. self.kernel_prob2,
  147. kernel_size,
  148. self.blur_sigma2,
  149. self.blur_sigma2, [-math.pi, math.pi],
  150. self.betag_range2,
  151. self.betap_range2,
  152. noise_range=None)
  153. # pad kernel
  154. pad_size = (21 - kernel_size) // 2
  155. kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
  156. # ------------------------------------- the final sinc kernel ------------------------------------- #
  157. if np.random.uniform() < self.opt['final_sinc_prob']:
  158. kernel_size = random.choice(self.kernel_range)
  159. omega_c = np.random.uniform(np.pi / 3, np.pi)
  160. sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
  161. sinc_kernel = torch.FloatTensor(sinc_kernel)
  162. else:
  163. sinc_kernel = self.pulse_tensor
  164. # BGR to RGB, HWC to CHW, numpy to tensor
  165. img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
  166. kernel = torch.FloatTensor(kernel)
  167. kernel2 = torch.FloatTensor(kernel2)
  168. return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
  169. return return_d
  170. def __len__(self):
  171. return len(self.paths)