| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- from extra.datasets.kits19 import iterate, preprocess
- from examples.mlperf.dataloader import batch_load_unet3d
- from test.external.mlperf_unet3d.kits19 import PytTrain, PytVal
- from tinygrad.helpers import temp
- from pathlib import Path
- import nibabel as nib
- import numpy as np
- import os
- import random
- import tempfile
- import unittest
class ExternalTestDatasets(unittest.TestCase):
    """Check that the tinygrad KiTS19 dataloaders yield the same samples as ``PytTrain``/``PytVal``.

    Each test writes synthetic NIfTI cases under the system temp directory,
    preprocesses them to ``.npy`` files, feeds those files to both loaders and
    compares the yielded image/label arrays element by element.
    """

    def _set_seed(self):
        """Seed NumPy's and the stdlib's global RNGs with the fixed value 42."""
        np.random.seed(42)
        random.seed(42)

    def _create_samples(self, val, num_samples=2):
        """Write ``num_samples`` identical synthetic cases to disk and preprocess them.

        Args:
            val: When truthy the preprocessed files go under ``<tmp>/val``,
                otherwise under ``<tmp>/train``.
            num_samples: Number of ``case_000<i>`` directories to create.

        Returns:
            A tuple ``(preproc_pth, image_paths, label_paths)`` where the two
            path lists hold the ``*_x.npy`` / ``*_y.npy`` files found in
            ``preproc_pth``, each sorted by name.
        """
        self._set_seed()

        tmp_root = Path(tempfile.gettempdir())
        volume_shape = (190, 392, 392)
        # Draw the image first and the label second, matching the seeded RNG sequence.
        img = np.random.rand(*volume_shape).astype(np.float32)
        lbl = np.random.randint(0, 100, size=volume_shape).astype(np.uint8)
        img, lbl = nib.Nifti1Image(img, np.eye(4)), nib.Nifti1Image(lbl, np.eye(4))

        dataset = "val" if val else "train"
        preproc_pth = tmp_root / dataset
        # The output directory does not depend on the sample index, so create it once.
        preproc_pth.mkdir(parents=True, exist_ok=True)

        for i in range(num_samples):
            case = f"case_000{i}"
            case_dir = tmp_root / case
            case_dir.mkdir(parents=True, exist_ok=True)

            # NOTE(review): temp() is assumed to resolve names relative to
            # tempfile.gettempdir(), i.e. inside the directories created above.
            nib.save(img, temp(f"{case}/imaging.nii.gz"))
            nib.save(lbl, temp(f"{case}/segmentation.nii.gz"))

            preproc_img, preproc_lbl = preprocess(case_dir)
            np.save(temp(f"{dataset}/{case}_x.npy"), preproc_img, allow_pickle=False)
            np.save(temp(f"{dataset}/{case}_y.npy"), preproc_lbl, allow_pickle=False)

        # Directory listings come back in arbitrary order; sort both lists so the
        # i-th image path always pairs with the i-th label path.
        return preproc_pth, sorted(preproc_pth.glob("*_x.npy")), sorted(preproc_pth.glob("*_y.npy"))

    def _create_kits19_ref_dataloader(self, preproc_img_pths, preproc_lbl_pths, val):
        """Return an iterator over ``PytVal`` (``val`` truthy) or ``PytTrain`` built from the given paths.

        ``PytTrain`` is constructed with ``patch_size=(128, 128, 128)`` and
        ``oversampling=0.4``.
        """
        if val:
            dataset = PytVal(preproc_img_pths, preproc_lbl_pths)
        else:
            dataset = PytTrain(preproc_img_pths, preproc_lbl_pths, patch_size=(128, 128, 128), oversampling=0.4)

        return iter(dataset)

    def _create_kits19_tinygrad_dataloader(self, preproc_pth, val, batch_size=1, shuffle=False, seed=42, use_old_dataloader=False):
        """Return an iterator over one of the tinygrad KiTS19 dataloaders.

        Args:
            preproc_pth: Directory handed to the loader as the location of the
                preprocessed data.
            val: Forwarded to the loader as its ``val`` flag.
            batch_size: Forwarded as ``bs`` (``iterate``) or ``batch_size``
                (``batch_load_unet3d``).
            shuffle: Forwarded to the loader.
            seed: Forwarded to ``batch_load_unet3d`` only; ``iterate`` does not
                receive it.
            use_old_dataloader: When truthy use ``iterate`` over every
                ``case_*`` entry in the system temp directory, otherwise use
                ``batch_load_unet3d``.
        """
        if use_old_dataloader:
            # Sorted so the case order does not depend on directory-listing order.
            case_dirs = sorted(Path(tempfile.gettempdir()).glob("case_*"))
            dataset = iterate(case_dirs, preprocessed_dir=preproc_pth, val=val, shuffle=shuffle, bs=batch_size)
        else:
            dataset = batch_load_unet3d(preproc_pth, batch_size=batch_size, val=val, shuffle=shuffle, seed=seed)

        return iter(dataset)

    def test_kits19_training_set(self):
        """Training samples from ``batch_load_unet3d`` must equal those from ``PytTrain``."""
        preproc_pth, preproc_img_pths, preproc_lbl_pths = self._create_samples(False)
        ref_dataset = self._create_kits19_ref_dataloader(preproc_img_pths, preproc_lbl_pths, False)
        tinygrad_dataset = self._create_kits19_tinygrad_dataloader(preproc_pth, False)

        num_compared = 0
        for ref_sample, tinygrad_sample in zip(ref_dataset, tinygrad_dataset):
            # NOTE(review): reseeding here presumably keeps the random
            # augmentation of the next fetched pair reproducible - confirm
            # against the loaders.
            self._set_seed()

            # Index 0 along axis 1 of the tinygrad tensors before comparing.
            np.testing.assert_equal(tinygrad_sample[0][:, 0].numpy(), ref_sample[0])
            np.testing.assert_equal(tinygrad_sample[1][:, 0].numpy(), ref_sample[1])
            num_compared += 1

        # zip() stops silently on an empty loader; do not let that pass as success.
        self.assertGreater(num_compared, 0, "no training samples were compared")

    def test_kits19_validation_set(self):
        """Validation samples from ``iterate`` must equal those from ``PytVal``."""
        _, preproc_img_pths, preproc_lbl_pths = self._create_samples(True)
        ref_dataset = self._create_kits19_ref_dataloader(preproc_img_pths, preproc_lbl_pths, True)
        # NOTE(review): the temp root (not the "val" subdirectory) is passed as
        # preprocessed_dir; presumably iterate() appends the split name itself - confirm.
        tinygrad_dataset = self._create_kits19_tinygrad_dataloader(Path(tempfile.gettempdir()), True, use_old_dataloader=True)

        num_compared = 0
        for ref_sample, tinygrad_sample in zip(ref_dataset, tinygrad_dataset):
            np.testing.assert_equal(tinygrad_sample[0][:, 0], ref_sample[0])
            np.testing.assert_equal(tinygrad_sample[1], ref_sample[1])
            num_compared += 1

        # zip() stops silently on an empty loader; do not let that pass as success.
        self.assertGreater(num_compared, 0, "no validation samples were compared")
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|