# extract_subimages.py (4.9 KB)

  1. import argparse
  2. import cv2
  3. import numpy as np
  4. import os
  5. import sys
  6. from basicsr.utils import scandir
  7. from multiprocessing import Pool
  8. from os import path as osp
  9. from tqdm import tqdm
  10. def main(args):
  11. """A multi-thread tool to crop large images to sub-images for faster IO.
  12. opt (dict): Configuration dict. It contains:
  13. n_thread (int): Thread number.
  14. compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size
  15. and longer compression time. Use 0 for faster CPU decompression. Default: 3, same in cv2.
  16. input_folder (str): Path to the input folder.
  17. save_folder (str): Path to save folder.
  18. crop_size (int): Crop size.
  19. step (int): Step for overlapped sliding window.
  20. thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
  21. Usage:
  22. For each folder, run this script.
  23. Typically, there are GT folder and LQ folder to be processed for DIV2K dataset.
  24. After process, each sub_folder should have the same number of subimages.
  25. Remember to modify opt configurations according to your settings.
  26. """
  27. opt = {}
  28. opt['n_thread'] = args.n_thread
  29. opt['compression_level'] = args.compression_level
  30. opt['input_folder'] = args.input
  31. opt['save_folder'] = args.output
  32. opt['crop_size'] = args.crop_size
  33. opt['step'] = args.step
  34. opt['thresh_size'] = args.thresh_size
  35. extract_subimages(opt)
  36. def extract_subimages(opt):
  37. """Crop images to subimages.
  38. Args:
  39. opt (dict): Configuration dict. It contains:
  40. input_folder (str): Path to the input folder.
  41. save_folder (str): Path to save folder.
  42. n_thread (int): Thread number.
  43. """
  44. input_folder = opt['input_folder']
  45. save_folder = opt['save_folder']
  46. if not osp.exists(save_folder):
  47. os.makedirs(save_folder)
  48. print(f'mkdir {save_folder} ...')
  49. else:
  50. print(f'Folder {save_folder} already exists. Exit.')
  51. sys.exit(1)
  52. # scan all images
  53. img_list = list(scandir(input_folder, full_path=True))
  54. pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
  55. pool = Pool(opt['n_thread'])
  56. for path in img_list:
  57. pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1))
  58. pool.close()
  59. pool.join()
  60. pbar.close()
  61. print('All processes done.')
  62. def worker(path, opt):
  63. """Worker for each process.
  64. Args:
  65. path (str): Image path.
  66. opt (dict): Configuration dict. It contains:
  67. crop_size (int): Crop size.
  68. step (int): Step for overlapped sliding window.
  69. thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
  70. save_folder (str): Path to save folder.
  71. compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
  72. Returns:
  73. process_info (str): Process information displayed in progress bar.
  74. """
  75. crop_size = opt['crop_size']
  76. step = opt['step']
  77. thresh_size = opt['thresh_size']
  78. img_name, extension = osp.splitext(osp.basename(path))
  79. # remove the x2, x3, x4 and x8 in the filename for DIV2K
  80. img_name = img_name.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
  81. img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
  82. h, w = img.shape[0:2]
  83. h_space = np.arange(0, h - crop_size + 1, step)
  84. if h - (h_space[-1] + crop_size) > thresh_size:
  85. h_space = np.append(h_space, h - crop_size)
  86. w_space = np.arange(0, w - crop_size + 1, step)
  87. if w - (w_space[-1] + crop_size) > thresh_size:
  88. w_space = np.append(w_space, w - crop_size)
  89. index = 0
  90. for x in h_space:
  91. for y in w_space:
  92. index += 1
  93. cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
  94. cropped_img = np.ascontiguousarray(cropped_img)
  95. cv2.imwrite(
  96. osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), cropped_img,
  97. [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
  98. process_info = f'Processing {img_name} ...'
  99. return process_info
  100. if __name__ == '__main__':
  101. parser = argparse.ArgumentParser()
  102. parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder')
  103. parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_HR_sub', help='Output folder')
  104. parser.add_argument('--crop_size', type=int, default=480, help='Crop size')
  105. parser.add_argument('--step', type=int, default=240, help='Step for overlapped sliding window')
  106. parser.add_argument(
  107. '--thresh_size',
  108. type=int,
  109. default=0,
  110. help='Threshold size. Patches whose size is lower than thresh_size will be dropped.')
  111. parser.add_argument('--n_thread', type=int, default=20, help='Thread number.')
  112. parser.add_argument('--compression_level', type=int, default=3, help='Compression level')
  113. args = parser.parse_args()
  114. main(args)