external_test_datasets.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from extra.datasets.kits19 import iterate, preprocess
  2. from examples.mlperf.dataloader import batch_load_unet3d
  3. from test.external.mlperf_unet3d.kits19 import PytTrain, PytVal
  4. from tinygrad.helpers import temp
  5. from pathlib import Path
  6. import nibabel as nib
  7. import numpy as np
  8. import os
  9. import random
  10. import tempfile
  11. import unittest
  12. class ExternalTestDatasets(unittest.TestCase):
  13. def _set_seed(self):
  14. np.random.seed(42)
  15. random.seed(42)
  16. def _create_samples(self, val, num_samples=2):
  17. self._set_seed()
  18. img, lbl = np.random.rand(190, 392, 392).astype(np.float32), np.random.randint(0, 100, size=(190, 392, 392)).astype(np.uint8)
  19. img, lbl = nib.Nifti1Image(img, np.eye(4)), nib.Nifti1Image(lbl, np.eye(4))
  20. dataset = "val" if val else "train"
  21. preproc_pth = Path(tempfile.gettempdir() + f"/{dataset}")
  22. for i in range(num_samples):
  23. os.makedirs(tempfile.gettempdir() + f"/case_000{i}", exist_ok=True)
  24. nib.save(img, temp(f"case_000{i}/imaging.nii.gz"))
  25. nib.save(lbl, temp(f"case_000{i}/segmentation.nii.gz"))
  26. preproc_img, preproc_lbl = preprocess(Path(tempfile.gettempdir()) / f"case_000{i}")
  27. preproc_img_pth, preproc_lbl_pth = temp(f"{dataset}/case_000{i}_x.npy"), temp(f"{dataset}/case_000{i}_y.npy")
  28. os.makedirs(preproc_pth, exist_ok=True)
  29. np.save(preproc_img_pth, preproc_img, allow_pickle=False)
  30. np.save(preproc_lbl_pth, preproc_lbl, allow_pickle=False)
  31. return preproc_pth, list(preproc_pth.glob("*_x.npy")), list(preproc_pth.glob("*_y.npy"))
  32. def _create_kits19_ref_dataloader(self, preproc_img_pths, preproc_lbl_pths, val):
  33. if val:
  34. dataset = PytVal(preproc_img_pths, preproc_lbl_pths)
  35. else:
  36. dataset = PytTrain(preproc_img_pths, preproc_lbl_pths, patch_size=(128, 128, 128), oversampling=0.4)
  37. return iter(dataset)
  38. def _create_kits19_tinygrad_dataloader(self, preproc_pth, val, batch_size=1, shuffle=False, seed=42, use_old_dataloader=False):
  39. if use_old_dataloader:
  40. dataset = iterate(list(Path(tempfile.gettempdir()).glob("case_*")), preprocessed_dir=preproc_pth, val=val, shuffle=shuffle, bs=batch_size)
  41. else:
  42. dataset = iter(batch_load_unet3d(preproc_pth, batch_size=batch_size, val=val, shuffle=shuffle, seed=seed))
  43. return iter(dataset)
  44. def test_kits19_training_set(self):
  45. preproc_pth, preproc_img_pths, preproc_lbl_pths = self._create_samples(False)
  46. ref_dataset = self._create_kits19_ref_dataloader(preproc_img_pths, preproc_lbl_pths, False)
  47. tinygrad_dataset = self._create_kits19_tinygrad_dataloader(preproc_pth, False)
  48. for ref_sample, tinygrad_sample in zip(ref_dataset, tinygrad_dataset):
  49. self._set_seed()
  50. np.testing.assert_equal(tinygrad_sample[0][:, 0].numpy(), ref_sample[0])
  51. np.testing.assert_equal(tinygrad_sample[1][:, 0].numpy(), ref_sample[1])
  52. def test_kits19_validation_set(self):
  53. _, preproc_img_pths, preproc_lbl_pths = self._create_samples(True)
  54. ref_dataset = self._create_kits19_ref_dataloader(preproc_img_pths, preproc_lbl_pths, True)
  55. tinygrad_dataset = self._create_kits19_tinygrad_dataloader(Path(tempfile.gettempdir()), True, use_old_dataloader=True)
  56. for ref_sample, tinygrad_sample in zip(ref_dataset, tinygrad_dataset):
  57. np.testing.assert_equal(tinygrad_sample[0][:, 0], ref_sample[0])
  58. np.testing.assert_equal(tinygrad_sample[1], ref_sample[1])
  59. if __name__ == '__main__':
  60. unittest.main()