# dataloader.py

import os, random, pickle, functools, itertools
from typing import List, Tuple
from pathlib import Path
import numpy as np
from PIL import Image
from tqdm import tqdm
from tinygrad import dtypes, Tensor
from tinygrad.helpers import getenv, prod, Context, round_up
from collections import deque
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count, Pool
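
# Minimal lock-guarded queue built on a single multiprocessing Pipe: objects are
# pickled on put() and unpickled on get(). Kept as a lighter-weight alternative to
# multiprocessing.Queue (see the commented-out q_in/q_out line in batch_load_resnet).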
class MyQueue:
  def __init__(self, multiple_readers=True, multiple_writers=True):
    self._reader, self._writer = connection.Pipe(duplex=False)
    self._rlock = Lock() if multiple_readers else None
    self._wlock = Lock() if multiple_writers else None
  def get(self):
    if self._rlock: self._rlock.acquire()
    ret = pickle.loads(self._reader.recv_bytes())
    if self._rlock: self._rlock.release()
    return ret
  def put(self, obj):
    if self._wlock: self._wlock.acquire()
    self._writer.send_bytes(pickle.dumps(obj))
    if self._wlock: self._wlock.release()
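
# Streaming Fisher-Yates shuffle: yields a uniformly random permutation of range(n)
# one index at a time, tracking only the displaced positions in a dict instead of
# materializing the full index list up front.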
def shuffled_indices(n, seed=None):
  rng = random.Random(seed)
  indices = {}
  for i in range(n-1, -1, -1):
    j = rng.randint(0, i)
    if i not in indices: indices[i] = i
    if j not in indices: indices[j] = j
    indices[i], indices[j] = indices[j], indices[i]
    yield indices[i]
    del indices[i]
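
# Worker process for the ResNet loader: receives (idx, filename, val) tuples, decodes
# and preprocesses the image (center crop for eval, random augmentation for training),
# writes the pixels straight into slot idx of the shared-memory tensor X, then reports
# idx on q_out. A None filename produces a padding image filled with the training mean.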
def loader_process(q_in, q_out, X:Tensor, seed):
  import signal
  signal.signal(signal.SIGINT, lambda _, __: exit(0))

  from extra.datasets.imagenet import center_crop, preprocess_train

  with Context(DEBUG=0):
    while (_recv := q_in.get()) is not None:
      idx, fn, val = _recv
      if fn is not None:
        img = Image.open(fn)
        img = img.convert('RGB') if img.mode != "RGB" else img

        if val:
          # eval: 76.08%, load in 0m7.366s (0m5.301s with simd)
          # sudo apt-get install libjpeg-dev
          # CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
          img = center_crop(img)
          img = np.array(img)
        else:
          # reseed rng for determinism
          if seed is not None:
            np.random.seed(seed * 2 ** 10 + idx)
            random.seed(seed * 2 ** 10 + idx)
          img = preprocess_train(img)
      else:
        # pad data with training mean
        img = np.tile(np.array([[[123.68, 116.78, 103.94]]], dtype=np.uint8), (224, 224, 1))

      # broken out
      #img_tensor = Tensor(img.tobytes(), device='CPU')
      #storage_tensor = X[idx].contiguous().realize().lazydata.realized
      #storage_tensor._copyin(img_tensor.numpy())

      # faster
      X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()

      # ideal
      #X[idx].assign(img.tobytes())  # NOTE: this is slow!

      q_out.put(idx)
  q_out.put(None)
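
# Generator that feeds ImageNet batches for ResNet. Images land in a /dev/shm-backed
# uint8 tensor holding BATCH_COUNT in-flight batches; worker processes fill slots and
# the Cookie returned with each batch re-enqueues its slot once the caller drops it.
# A minimal usage sketch, assuming ImageNet is already set up for extra.datasets.imagenet:
#   for X, Y, cookie in batch_load_resnet(batch_size=64, val=False, seed=42):
#     ...  # X: uint8 disk tensor slice of shape (64, 224, 224, 3) in NHWC,
#          # Y: list of class labels from get_imagenet_categories() (-1 for padding)
#     # keep `cookie` referenced until the batch is consumed; its __del__ refills the slot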
def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None, pad_first_batch=False):
  from extra.datasets.imagenet import get_train_files, get_val_files
  files = get_val_files() if val else get_train_files()
  from extra.datasets.imagenet import get_imagenet_categories
  cir = get_imagenet_categories()

  if pad_first_batch:
    FIRST_BATCH_PAD = round_up(len(files), batch_size) - len(files)
  else:
    FIRST_BATCH_PAD = 0
  file_count = FIRST_BATCH_PAD + len(files)

  BATCH_COUNT = min(32, file_count // batch_size)

  def _gen():
    for _ in range(FIRST_BATCH_PAD): yield -1
    yield from shuffled_indices(len(files), seed=seed) if shuffle else iter(range(len(files)))
  gen = iter(_gen())

  def enqueue_batch(num):
    for idx in range(num*batch_size, (num+1)*batch_size):
      fidx = next(gen)
      if fidx != -1:
        fn = files[fidx]
        q_in.put((idx, fn, val))
        Y[idx] = cir[fn.split("/")[-2]]
      else:
        # padding
        q_in.put((idx, None, val))
        Y[idx] = -1

  shutdown = False
  class Cookie:
    def __init__(self, num): self.num = num
    def __del__(self):
      if not shutdown:
        try: enqueue_batch(self.num)
        except StopIteration: pass

  gotten = [0]*BATCH_COUNT
  def receive_batch():
    while 1:
      num = q_out.get()//batch_size
      gotten[num] += 1
      if gotten[num] == batch_size: break
    gotten[num] = 0
    return X[num*batch_size:(num+1)*batch_size], Y[num*batch_size:(num+1)*batch_size], Cookie(num)

  #q_in, q_out = MyQueue(multiple_writers=False), MyQueue(multiple_readers=False)
  q_in, q_out = Queue(), Queue()

  sz = (batch_size*BATCH_COUNT, 224, 224, 3)
  if os.path.exists("/dev/shm/resnet_X"): os.unlink("/dev/shm/resnet_X")
  shm = shared_memory.SharedMemory(name="resnet_X", create=True, size=prod(sz))

  procs = []

  try:
    # disk:shm is slower
    #X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:shm:{shm.name}")
    X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/resnet_X")
    Y = [None] * (batch_size*BATCH_COUNT)

    for _ in range(cpu_count()):
      p = Process(target=loader_process, args=(q_in, q_out, X, seed))
      p.daemon = True
      p.start()
      procs.append(p)

    for bn in range(BATCH_COUNT): enqueue_batch(bn)

    # NOTE: this is batch aligned, last ones are ignored unless pad_first_batch is True
    for _ in range(0, file_count//batch_size): yield receive_batch()
  finally:
    shutdown = True
    # empty queues
    for _ in procs: q_in.put(None)
    q_in.close()
    for _ in procs:
      while q_out.get() is not None: pass
    q_out.close()
    # shutdown processes
    for p in procs: p.join()
    shm.close()
    try:
      shm.unlink()
    except FileNotFoundError:
      # happens with BENCHMARK set
      pass
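
# BERT pretraining data: each pickle returned by get_wiki_train_files/get_wiki_val_files
# (extra.datasets.wikipedia) holds a list of pre-featurized sample dicts. load_bert_file
# caches whole files; process_batch_bert concatenates the per-sample arrays and wraps
# each field in a Tensor.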
@functools.lru_cache(maxsize=128)
def load_bert_file(fn:str) -> List[dict]:
  with open(fn, "rb") as f: data = pickle.load(f)
  return data

def process_batch_bert(data: List[dict]) -> dict[str, Tensor]:
  return {
    "input_ids": Tensor(np.concatenate([s["input_ids"] for s in data], axis=0), dtype=dtypes.float32),
    "input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.default_float),
    "segment_ids": Tensor(np.concatenate([s["segment_ids"] for s in data], axis=0), dtype=dtypes.float32),
    "masked_lm_positions": Tensor(np.concatenate([s["masked_lm_positions"] for s in data], axis=0), dtype=dtypes.float32),
    "masked_lm_ids": Tensor(np.concatenate([s["masked_lm_ids"] for s in data], axis=0), dtype=dtypes.float32),
    "masked_lm_weights": Tensor(np.concatenate([s["masked_lm_weights"] for s in data], axis=0), dtype=dtypes.float32),
    "next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.float32),
  }
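
# Shuffles the order of dataset parts (the stem prefix before the first '_') while
# keeping each part's files in the numeric order of the second stem field.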
def shuffle_parts(file_paths: List[str]) -> List[str]:
  parts = {}
  for f in file_paths:
    part = Path(f).stem.split('_')[0]
    if part not in parts: parts[part] = []
    parts[part].append(f)

  part_ids = list(parts.keys())
  random.shuffle(part_ids)

  shuffled_files = []
  for p in part_ids:
    parts[p].sort(key=lambda x: int(Path(x).stem.split('_')[1]))
    shuffled_files.extend(parts[p])
  return shuffled_files

def random_sample(data: List[str]):
  index = random.randint(0, len(data) - 1)
  selected_sample = data[index]
  return selected_sample, index

def load_datasample(file_and_offset:Tuple[str, int]) -> List[dict]:
  data = load_bert_file(file_and_offset[0])
  return data[file_and_offset[1]]
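
# Training loader: builds an index of (file, sample offset) pairs, then draws each batch
# by sampling without replacement from a 1000-sample active window that is topped up
# from the remaining data, following the MLCommons reference linked below.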
# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 394
def batch_load_train_bert(BS:int, start_step:int = 0):
  from extra.datasets.wikipedia import get_wiki_train_files
  files = shuffle_parts(get_wiki_train_files())
  dataset = []
  for f in tqdm(files, desc="Building dataset"):
    lists = [(f, o) for o in range(int(Path(f).stem.split("_")[3].split(".")[0]))]
    dataset.extend(lists)

  dataset = dataset[start_step:]

  active_set = deque(dataset[:1000])
  remaining_set = deque(dataset[1000:])

  while dataset:
    blob = []
    for _ in range(BS):
      if active_set:
        index = random.randint(0, len(active_set) - 1)
        sample = active_set[index]
        active_set.remove(sample)
        blob.append(sample)

        if remaining_set:
          active_set.append(remaining_set.popleft())
    yield process_batch_bert([load_datasample(sample) for sample in blob])
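
# Validation loader: concatenates every validation file into one list and yields
# fixed-size batches forever, wrapping around the end of the dataset.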
# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 416
def batch_load_val_bert(BS:int):
  from extra.datasets.wikipedia import get_wiki_val_files
  files = get_wiki_val_files()
  dataset = list(itertools.chain.from_iterable([load_bert_file(f) for f in files]))
  idx = 0
  while True:
    start_idx = (idx * BS) % len(dataset)
    end_idx = ((idx + 1) * BS) % len(dataset)
    if start_idx < end_idx:
      yield process_batch_bert(dataset[start_idx:end_idx])
    else:  # wrap around the end to the beginning of the dataset
      yield process_batch_bert(dataset[start_idx:] + dataset[:end_idx])
    idx += 1
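
# Worker process for the UNet3D (KiTS19) loader: loads the preprocessed *_x.npy / *_y.npy
# volumes, applies random crop/flip/brightness/noise augmentation for training, and
# writes the result into slot idx of the shared-memory tensors X and Y.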
def load_unet3d_data(preprocessed_dataset_dir, seed, queue_in, queue_out, X:Tensor, Y:Tensor):
  from extra.datasets.kits19 import rand_balanced_crop, rand_flip, random_brightness_augmentation, gaussian_noise

  while (data := queue_in.get()) is not None:
    idx, fn, val = data
    case_name = os.path.basename(fn).split("_x.npy")[0]
    x, y = np.load(preprocessed_dataset_dir / f"{case_name}_x.npy"), np.load(preprocessed_dataset_dir / f"{case_name}_y.npy")

    if not val:
      if seed is not None:
        np.random.seed(seed)
        random.seed(seed)
      x, y = rand_balanced_crop(x, y)
      x, y = rand_flip(x, y)
      x, y = x.astype(np.float32), y.astype(np.uint8)
      x = random_brightness_augmentation(x)
      x = gaussian_noise(x)

    X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = x.tobytes()
    Y[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = y.tobytes()

    queue_out.put(idx)
  queue_out.put(None)
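
# Same shared-memory ring + Cookie scheme as batch_load_resnet, with two /dev/shm buffers
# of shape (batch_size * batch_count, 1, 128, 128, 128) for images (float32) and labels (uint8).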
def batch_load_unet3d(preprocessed_dataset_dir:Path, batch_size:int=6, val:bool=False, shuffle:bool=True, seed=None):
  assert preprocessed_dataset_dir is not None, "run preprocess_data on kits19"

  files = sorted(list(preprocessed_dataset_dir.glob("*_x.npy")))
  file_indices = list(range(len(files)))
  batch_count = min(32, len(files) // batch_size)

  queue_in, queue_out = Queue(), Queue()
  procs, data_out_count = [], [0] * batch_count
  shm_name_x, shm_name_y = "unet3d_x", "unet3d_y"
  sz = (batch_size * batch_count, 1, 128, 128, 128)
  if os.path.exists(f"/dev/shm/{shm_name_x}"): os.unlink(f"/dev/shm/{shm_name_x}")
  if os.path.exists(f"/dev/shm/{shm_name_y}"): os.unlink(f"/dev/shm/{shm_name_y}")
  shm_x = shared_memory.SharedMemory(name=shm_name_x, create=True, size=prod(sz))
  shm_y = shared_memory.SharedMemory(name=shm_name_y, create=True, size=prod(sz))

  shutdown = False
  class Cookie:
    def __init__(self, bc):
      self.bc = bc
    def __del__(self):
      if not shutdown:
        try: enqueue_batch(self.bc)
        except StopIteration: pass

  def enqueue_batch(bc):
    for idx in range(bc * batch_size, (bc+1) * batch_size):
      fn = files[next(ds_iter)]
      queue_in.put((idx, fn, val))

  def shuffle_indices(file_indices, seed=None):
    rng = random.Random(seed)
    rng.shuffle(file_indices)

  if shuffle: shuffle_indices(file_indices, seed=seed)
  ds_iter = iter(file_indices)

  try:
    X = Tensor.empty(*sz, dtype=dtypes.float32, device=f"disk:/dev/shm/{shm_name_x}")
    Y = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/{shm_name_y}")

    for _ in range(cpu_count()):
      proc = Process(target=load_unet3d_data, args=(preprocessed_dataset_dir, seed, queue_in, queue_out, X, Y))
      proc.daemon = True
      proc.start()
      procs.append(proc)

    for bc in range(batch_count):
      enqueue_batch(bc)

    for _ in range(len(files) // batch_size):
      while True:
        bc = queue_out.get() // batch_size
        data_out_count[bc] += 1
        if data_out_count[bc] == batch_size: break

      data_out_count[bc] = 0
      yield X[bc * batch_size:(bc + 1) * batch_size], Y[bc * batch_size:(bc + 1) * batch_size], Cookie(bc)
  finally:
    shutdown = True

    for _ in procs: queue_in.put(None)
    queue_in.close()

    for _ in procs:
      while queue_out.get() is not None: pass
    queue_out.close()

    # shutdown processes
    for proc in procs: proc.join()

    shm_x.close()
    shm_y.close()
    try:
      shm_x.unlink()
      shm_y.unlink()
    except FileNotFoundError:
      # happens with BENCHMARK set
      pass
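
# Running this file directly smoke-tests a loader with a tqdm progress bar. An assumed
# invocation (env vars read via tinygrad's getenv):
#   MODEL=resnet VAL=1 python dataloader.py
#   MODEL=unet3d VAL=0 python dataloader.py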
if __name__ == "__main__":
  def load_unet3d(val):
    assert not val, "validation set is not supported due to different sizes on inputs"
    from extra.datasets.kits19 import get_train_files, get_val_files, preprocess_dataset, BASEDIR
    preprocessed_dataset_dir = (BASEDIR / ".." / "preprocessed" / ("val" if val else "train"))
    files = get_val_files() if val else get_train_files()

    if not preprocessed_dataset_dir.exists(): preprocess_dataset(files, preprocessed_dataset_dir, val)

    with tqdm(total=len(files)) as pbar:
      for x, _, _ in batch_load_unet3d(preprocessed_dataset_dir, val=val):
        pbar.update(x.shape[0])

  def load_resnet(val):
    from extra.datasets.imagenet import get_train_files, get_val_files
    files = get_val_files() if val else get_train_files()
    with tqdm(total=len(files)) as pbar:
      for x,y,c in batch_load_resnet(val=val):
        pbar.update(x.shape[0])

  load_fn_name = f"load_{getenv('MODEL', 'resnet')}"
  if load_fn_name in globals():
    globals()[load_fn_name](getenv("VAL", 1))