extract_subimages.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import argparse
  2. import cv2
  3. import numpy as np
  4. import os
  5. import sys
  6. from basicsr.utils import scandir
  7. from multiprocessing import Pool
  8. from os import path as osp
  9. from tqdm import tqdm
  10. def main(args):
  11. """A multi-thread tool to crop large images to sub-images for faster IO.
  12. opt (dict): Configuration dict. It contains:
  13. n_thread (int): Thread number.
  14. compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9.
  15. A higher value means a smaller size and longer compression time.
  16. Use 0 for faster CPU decompression. Default: 3, same in cv2.
  17. input_folder (str): Path to the input folder.
  18. save_folder (str): Path to save folder.
  19. crop_size (int): Crop size.
  20. step (int): Step for overlapped sliding window.
  21. thresh_size (int): Threshold size. Patches whose size is lower
  22. than thresh_size will be dropped.
  23. Usage:
  24. For each folder, run this script.
  25. Typically, there are four folders to be processed for DIV2K dataset.
  26. DIV2K_train_HR
  27. DIV2K_train_LR_bicubic/X2
  28. DIV2K_train_LR_bicubic/X3
  29. DIV2K_train_LR_bicubic/X4
  30. After process, each sub_folder should have the same number of
  31. subimages.
  32. Remember to modify opt configurations according to your settings.
  33. """
  34. opt = {}
  35. opt['n_thread'] = args.n_thread
  36. opt['compression_level'] = args.compression_level
  37. # HR images
  38. opt['input_folder'] = args.input
  39. opt['save_folder'] = args.output
  40. opt['crop_size'] = args.crop_size
  41. opt['step'] = args.step
  42. opt['thresh_size'] = args.thresh_size
  43. extract_subimages(opt)
  44. def extract_subimages(opt):
  45. """Crop images to subimages.
  46. Args:
  47. opt (dict): Configuration dict. It contains:
  48. input_folder (str): Path to the input folder.
  49. save_folder (str): Path to save folder.
  50. n_thread (int): Thread number.
  51. """
  52. input_folder = opt['input_folder']
  53. save_folder = opt['save_folder']
  54. if not osp.exists(save_folder):
  55. os.makedirs(save_folder)
  56. print(f'mkdir {save_folder} ...')
  57. else:
  58. print(f'Folder {save_folder} already exists. Exit.')
  59. sys.exit(1)
  60. img_list = list(scandir(input_folder, full_path=True))
  61. pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
  62. pool = Pool(opt['n_thread'])
  63. for path in img_list:
  64. pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1))
  65. pool.close()
  66. pool.join()
  67. pbar.close()
  68. print('All processes done.')
  69. def worker(path, opt):
  70. """Worker for each process.
  71. Args:
  72. path (str): Image path.
  73. opt (dict): Configuration dict. It contains:
  74. crop_size (int): Crop size.
  75. step (int): Step for overlapped sliding window.
  76. thresh_size (int): Threshold size. Patches whose size is lower
  77. than thresh_size will be dropped.
  78. save_folder (str): Path to save folder.
  79. compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
  80. Returns:
  81. process_info (str): Process information displayed in progress bar.
  82. """
  83. crop_size = opt['crop_size']
  84. step = opt['step']
  85. thresh_size = opt['thresh_size']
  86. img_name, extension = osp.splitext(osp.basename(path))
  87. # remove the x2, x3, x4 and x8 in the filename for DIV2K
  88. img_name = img_name.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
  89. img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
  90. h, w = img.shape[0:2]
  91. h_space = np.arange(0, h - crop_size + 1, step)
  92. if h - (h_space[-1] + crop_size) > thresh_size:
  93. h_space = np.append(h_space, h - crop_size)
  94. w_space = np.arange(0, w - crop_size + 1, step)
  95. if w - (w_space[-1] + crop_size) > thresh_size:
  96. w_space = np.append(w_space, w - crop_size)
  97. index = 0
  98. for x in h_space:
  99. for y in w_space:
  100. index += 1
  101. cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
  102. cropped_img = np.ascontiguousarray(cropped_img)
  103. cv2.imwrite(
  104. osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), cropped_img,
  105. [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
  106. process_info = f'Processing {img_name} ...'
  107. return process_info
  108. if __name__ == '__main__':
  109. parser = argparse.ArgumentParser()
  110. parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder')
  111. parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_HR_sub', help='Output folder')
  112. parser.add_argument('--crop_size', type=int, default=480, help='Crop size')
  113. parser.add_argument('--step', type=int, default=240, help='Step for overlapped sliding window')
  114. parser.add_argument(
  115. '--thresh_size',
  116. type=int,
  117. default=0,
  118. help='Threshold size. Patches whose size is lower than thresh_size will be dropped.')
  119. parser.add_argument('--n_thread', type=int, default=20, help='Thread number.')
  120. parser.add_argument('--compression_level', type=int, default=3, help='Compression level')
  121. args = parser.parse_args()
  122. main(args)