# kits19.py — KiTS19 dataset loading, preprocessing and augmentation (MLPerf UNet3D)
  1. import random
  2. import functools
  3. from pathlib import Path
  4. import numpy as np
  5. import nibabel as nib
  6. from scipy import signal, ndimage
  7. import os
  8. import torch
  9. import torch.nn.functional as F
  10. from tqdm import tqdm
  11. from tinygrad.tensor import Tensor
  12. from tinygrad.helpers import fetch
  13. BASEDIR = Path(__file__).parent / "kits19" / "data"
  14. PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed"
  15. """
  16. To download the dataset:
  17. ```sh
  18. git clone https://github.com/neheller/kits19
  19. cd kits19
  20. pip3 install -r requirements.txt
  21. python3 -m starter_code.get_imaging
  22. cd ..
  23. mv kits19 extra/datasets
  24. ```
  25. """
  26. @functools.lru_cache(None)
  27. def get_train_files():
  28. return sorted([x for x in BASEDIR.iterdir() if x.stem.startswith("case") and int(x.stem.split("_")[-1]) < 210 and x not in get_val_files()])
  29. @functools.lru_cache(None)
  30. def get_val_files():
  31. data = fetch("https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt").read_text()
  32. return sorted([x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in data.split("\n")])
  33. def load_pair(file_path):
  34. image, label = nib.load(file_path / "imaging.nii.gz"), nib.load(file_path / "segmentation.nii.gz")
  35. image_spacings = image.header["pixdim"][1:4].tolist()
  36. image, label = image.get_fdata().astype(np.float32), label.get_fdata().astype(np.uint8)
  37. image, label = np.expand_dims(image, 0), np.expand_dims(label, 0)
  38. return image, label, image_spacings
  39. def resample3d(image, label, image_spacings, target_spacing=(1.6, 1.2, 1.2)):
  40. if image_spacings != target_spacing:
  41. spc_arr, targ_arr, shp_arr = np.array(image_spacings), np.array(target_spacing), np.array(image.shape[1:])
  42. new_shape = (spc_arr / targ_arr * shp_arr).astype(int).tolist()
  43. image = F.interpolate(torch.from_numpy(np.expand_dims(image, axis=0)), size=new_shape, mode="trilinear", align_corners=True)
  44. label = F.interpolate(torch.from_numpy(np.expand_dims(label, axis=0)), size=new_shape, mode="nearest")
  45. image = np.squeeze(image.numpy(), axis=0)
  46. label = np.squeeze(label.numpy(), axis=0)
  47. return image, label
  48. def normal_intensity(image, min_clip=-79.0, max_clip=304.0, mean=101.0, std=76.9):
  49. image = np.clip(image, min_clip, max_clip)
  50. image = (image - mean) / std
  51. return image
  52. def pad_to_min_shape(image, label, roi_shape=(128, 128, 128)):
  53. current_shape = image.shape[1:]
  54. bounds = [max(0, roi_shape[i] - current_shape[i]) for i in range(3)]
  55. paddings = [(0, 0)] + [(bounds[i] // 2, bounds[i] - bounds[i] // 2) for i in range(3)]
  56. image = np.pad(image, paddings, mode="edge")
  57. label = np.pad(label, paddings, mode="edge")
  58. return image, label
  59. def preprocess(file_path):
  60. image, label, image_spacings = load_pair(file_path)
  61. image, label = resample3d(image, label, image_spacings)
  62. image = normal_intensity(image.copy())
  63. image, label = pad_to_min_shape(image, label)
  64. return image, label
  65. def preprocess_dataset(filenames, preprocessed_dir, val):
  66. preprocessed_dataset_dir = (preprocessed_dir / ("val" if val else "train")) if preprocessed_dir is not None else None
  67. if not preprocessed_dataset_dir.is_dir(): os.makedirs(preprocessed_dataset_dir)
  68. for fn in tqdm(filenames, desc=f"preprocessing {'validation' if val else 'training'}"):
  69. case = os.path.basename(fn)
  70. image, label = preprocess(fn)
  71. image, label = image.astype(np.float32), label.astype(np.uint8)
  72. np.save(preprocessed_dataset_dir / f"{case}_x.npy", image, allow_pickle=False)
  73. np.save(preprocessed_dataset_dir / f"{case}_y.npy", label, allow_pickle=False)
  74. def iterate(files, preprocessed_dir=None, val=True, shuffle=False, bs=1):
  75. order = list(range(0, len(files)))
  76. preprocessed_dataset_dir = (preprocessed_dir / ("val" if val else "train")) if preprocessed_dir is not None else None
  77. if shuffle: random.shuffle(order)
  78. for i in range(0, len(files), bs):
  79. samples = []
  80. for i in order[i:i+bs]:
  81. if preprocessed_dataset_dir is not None:
  82. x_cached_path, y_cached_path = preprocessed_dataset_dir / f"{os.path.basename(files[i])}_x.npy", preprocessed_dataset_dir / f"{os.path.basename(files[i])}_y.npy"
  83. if x_cached_path.exists() and y_cached_path.exists():
  84. samples += [(np.load(x_cached_path), np.load(y_cached_path))]
  85. else: samples += [preprocess(files[i])]
  86. X, Y = [x[0] for x in samples], [x[1] for x in samples]
  87. if val:
  88. yield X[0][None], Y[0]
  89. else:
  90. X_preprocessed, Y_preprocessed = [], []
  91. for x, y in zip(X, Y):
  92. x, y = rand_balanced_crop(x, y)
  93. x, y = rand_flip(x, y)
  94. x, y = x.astype(np.float32), y.astype(np.uint8)
  95. x = random_brightness_augmentation(x)
  96. x = gaussian_noise(x)
  97. X_preprocessed.append(x)
  98. Y_preprocessed.append(y)
  99. yield np.stack(X_preprocessed, axis=0), np.stack(Y_preprocessed, axis=0)
  100. def gaussian_kernel(n, std):
  101. gaussian_1d = signal.windows.gaussian(n, std)
  102. gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
  103. gaussian_3d = np.outer(gaussian_2d, gaussian_1d)
  104. gaussian_3d = gaussian_3d.reshape(n, n, n)
  105. gaussian_3d = np.cbrt(gaussian_3d)
  106. gaussian_3d /= gaussian_3d.max()
  107. return gaussian_3d
  108. def pad_input(volume, roi_shape, strides, padding_mode="constant", padding_val=-2.2, dim=3):
  109. bounds = [(strides[i] - volume.shape[2:][i] % strides[i]) % strides[i] for i in range(dim)]
  110. bounds = [bounds[i] if (volume.shape[2:][i] + bounds[i]) >= roi_shape[i] else bounds[i] + strides[i] for i in range(dim)]
  111. paddings = [bounds[2]//2, bounds[2]-bounds[2]//2, bounds[1]//2, bounds[1]-bounds[1]//2, bounds[0]//2, bounds[0]-bounds[0]//2, 0, 0, 0, 0]
  112. return F.pad(torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val).numpy(), paddings
  113. def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5):
  114. from tinygrad.engine.jit import TinyJit
  115. mdl_run = TinyJit(lambda x: model(x).realize())
  116. image_shape, dim = list(inputs.shape[2:]), len(inputs.shape[2:])
  117. strides = [int(roi_shape[i] * (1 - overlap)) for i in range(dim)]
  118. bounds = [image_shape[i] % strides[i] for i in range(dim)]
  119. bounds = [bounds[i] if bounds[i] < strides[i] // 2 else 0 for i in range(dim)]
  120. inputs = inputs[
  121. ...,
  122. bounds[0]//2:image_shape[0]-(bounds[0]-bounds[0]//2),
  123. bounds[1]//2:image_shape[1]-(bounds[1]-bounds[1]//2),
  124. bounds[2]//2:image_shape[2]-(bounds[2]-bounds[2]//2),
  125. ]
  126. labels = labels[
  127. ...,
  128. bounds[0]//2:image_shape[0]-(bounds[0]-bounds[0]//2),
  129. bounds[1]//2:image_shape[1]-(bounds[1]-bounds[1]//2),
  130. bounds[2]//2:image_shape[2]-(bounds[2]-bounds[2]//2),
  131. ]
  132. inputs, paddings = pad_input(inputs, roi_shape, strides)
  133. padded_shape = inputs.shape[2:]
  134. size = [(inputs.shape[2:][i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)]
  135. result = np.zeros((1, 3, *padded_shape), dtype=np.float32)
  136. norm_map = np.zeros((1, 3, *padded_shape), dtype=np.float32)
  137. norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0])
  138. norm_patch = np.expand_dims(norm_patch, axis=0)
  139. for i in range(0, strides[0] * size[0], strides[0]):
  140. for j in range(0, strides[1] * size[1], strides[1]):
  141. for k in range(0, strides[2] * size[2], strides[2]):
  142. out = mdl_run(Tensor(inputs[..., i:roi_shape[0]+i,j:roi_shape[1]+j, k:roi_shape[2]+k])).numpy()
  143. result[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += out * norm_patch
  144. norm_map[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += norm_patch
  145. result /= norm_map
  146. result = result[..., paddings[4]:image_shape[0]+paddings[4], paddings[2]:image_shape[1]+paddings[2], paddings[0]:image_shape[2]+paddings[0]]
  147. return result, labels
  148. def rand_flip(image, label, axis=(1, 2, 3)):
  149. prob = 1 / len(axis)
  150. for ax in axis:
  151. if random.random() < prob:
  152. image = np.flip(image, axis=ax).copy()
  153. label = np.flip(label, axis=ax).copy()
  154. return image, label
  155. def random_brightness_augmentation(image, low=0.7, high=1.3, prob=0.1):
  156. if random.random() < prob:
  157. factor = np.random.uniform(low=low, high=high, size=1)
  158. image = (image * (1 + factor)).astype(image.dtype)
  159. return image
  160. def gaussian_noise(image, mean=0.0, std=0.1, prob=0.1):
  161. if random.random() < prob:
  162. scale = np.random.uniform(low=0.0, high=std)
  163. noise = np.random.normal(loc=mean, scale=scale, size=image.shape).astype(image.dtype)
  164. image += noise
  165. return image
  166. def _rand_foreg_cropb(image, label, patch_size):
  167. def adjust(foreg_slice, label, idx):
  168. diff = patch_size[idx - 1] - (foreg_slice[idx].stop - foreg_slice[idx].start)
  169. sign = -1 if diff < 0 else 1
  170. diff = abs(diff)
  171. ladj = 0 if diff == 0 else random.randrange(diff)
  172. hadj = diff - ladj
  173. low = max(0, foreg_slice[idx].start - sign * ladj)
  174. high = min(label.shape[idx], foreg_slice[idx].stop + sign * hadj)
  175. diff = patch_size[idx - 1] - (high - low)
  176. if diff > 0 and low == 0: high += diff
  177. elif diff > 0: low -= diff
  178. return low, high
  179. cl = np.random.choice(np.unique(label[label > 0]))
  180. foreg_slices = ndimage.find_objects(ndimage.label(label==cl)[0])
  181. foreg_slices = [x for x in foreg_slices if x is not None]
  182. slice_volumes = [np.prod([s.stop - s.start for s in sl]) for sl in foreg_slices]
  183. slice_idx = np.argsort(slice_volumes)[-2:]
  184. foreg_slices = [foreg_slices[i] for i in slice_idx]
  185. if not foreg_slices: return _rand_crop(image, label)
  186. foreg_slice = foreg_slices[random.randrange(len(foreg_slices))]
  187. low_x, high_x = adjust(foreg_slice, label, 1)
  188. low_y, high_y = adjust(foreg_slice, label, 2)
  189. low_z, high_z = adjust(foreg_slice, label, 3)
  190. image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
  191. label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
  192. return image, label
  193. def _rand_crop(image, label, patch_size):
  194. ranges = [s - p for s, p in zip(image.shape[1:], patch_size)]
  195. cord = [0 if x == 0 else random.randrange(x) for x in ranges]
  196. low_x, high_x = cord[0], cord[0] + patch_size[0]
  197. low_y, high_y = cord[1], cord[1] + patch_size[1]
  198. low_z, high_z = cord[2], cord[2] + patch_size[2]
  199. image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
  200. label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
  201. return image, label
  202. def rand_balanced_crop(image, label, patch_size=(128, 128, 128), oversampling=0.4):
  203. if random.random() < oversampling:
  204. image, label = _rand_foreg_cropb(image, label, patch_size)
  205. else:
  206. image, label = _rand_crop(image, label, patch_size)
  207. return image, label
  208. if __name__ == "__main__":
  209. for X, Y in iterate(get_val_files()):
  210. print(X.shape, Y.shape)