realesrgan_paired_dataset.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import os
  2. from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
  3. from basicsr.data.transforms import augment, paired_random_crop
  4. from basicsr.utils import FileClient, imfrombytes, img2tensor
  5. from basicsr.utils.registry import DATASET_REGISTRY
  6. from torch.utils import data as data
  7. from torchvision.transforms.functional import normalize
  8. @DATASET_REGISTRY.register()
  9. class RealESRGANPairedDataset(data.Dataset):
  10. """Paired image dataset for image restoration.
  11. Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
  12. There are three modes:
  13. 1. 'lmdb': Use lmdb files.
  14. If opt['io_backend'] == lmdb.
  15. 2. 'meta_info': Use meta information file to generate paths.
  16. If opt['io_backend'] != lmdb and opt['meta_info'] is not None.
  17. 3. 'folder': Scan folders to generate paths.
  18. The rest.
  19. Args:
  20. opt (dict): Config for train datasets. It contains the following keys:
  21. dataroot_gt (str): Data root path for gt.
  22. dataroot_lq (str): Data root path for lq.
  23. meta_info (str): Path for meta information file.
  24. io_backend (dict): IO backend type and other kwarg.
  25. filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
  26. Default: '{}'.
  27. gt_size (int): Cropped patched size for gt patches.
  28. use_hflip (bool): Use horizontal flips.
  29. use_rot (bool): Use rotation (use vertical flip and transposing h
  30. and w for implementation).
  31. scale (bool): Scale, which will be added automatically.
  32. phase (str): 'train' or 'val'.
  33. """
  34. def __init__(self, opt):
  35. super(RealESRGANPairedDataset, self).__init__()
  36. self.opt = opt
  37. self.file_client = None
  38. self.io_backend_opt = opt['io_backend']
  39. # mean and std for normalizing the input images
  40. self.mean = opt['mean'] if 'mean' in opt else None
  41. self.std = opt['std'] if 'std' in opt else None
  42. self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
  43. self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'
  44. # file client (lmdb io backend)
  45. if self.io_backend_opt['type'] == 'lmdb':
  46. self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
  47. self.io_backend_opt['client_keys'] = ['lq', 'gt']
  48. self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
  49. elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
  50. # disk backend with meta_info
  51. # Each line in the meta_info describes the relative path to an image
  52. with open(self.opt['meta_info']) as fin:
  53. paths = [line.strip().split(' ')[0] for line in fin]
  54. self.paths = []
  55. for path in paths:
  56. gt_path, lq_path = path.split(', ')
  57. gt_path = os.path.join(self.gt_folder, gt_path)
  58. lq_path = os.path.join(self.lq_folder, lq_path)
  59. self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
  60. else:
  61. # disk backend
  62. # it will scan the whole folder to get meta info
  63. # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
  64. self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
  65. def __getitem__(self, index):
  66. if self.file_client is None:
  67. self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
  68. scale = self.opt['scale']
  69. # Load gt and lq images. Dimension order: HWC; channel order: BGR;
  70. # image range: [0, 1], float32.
  71. gt_path = self.paths[index]['gt_path']
  72. img_bytes = self.file_client.get(gt_path, 'gt')
  73. img_gt = imfrombytes(img_bytes, float32=True)
  74. lq_path = self.paths[index]['lq_path']
  75. img_bytes = self.file_client.get(lq_path, 'lq')
  76. img_lq = imfrombytes(img_bytes, float32=True)
  77. # augmentation for training
  78. if self.opt['phase'] == 'train':
  79. gt_size = self.opt['gt_size']
  80. # random crop
  81. img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
  82. # flip, rotation
  83. img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
  84. # BGR to RGB, HWC to CHW, numpy to tensor
  85. img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
  86. # normalize
  87. if self.mean is not None or self.std is not None:
  88. normalize(img_lq, self.mean, self.std, inplace=True)
  89. normalize(img_gt, self.mean, self.std, inplace=True)
  90. return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
  91. def __len__(self):
  92. return len(self.paths)