external_test_losses.py 837 B

1234567891011121314151617181920
  1. from tinygrad import Tensor
  2. from test.external.mlperf_unet3d.dice import DiceCELoss
  3. from examples.mlperf.losses import dice_ce_loss
  4. import numpy as np
  5. import torch
  6. import unittest
  7. class ExternalTestLosses(unittest.TestCase):
  8. def _test_losses(self, tinygrad_metrics, orig_metrics, pred, label):
  9. tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(label)).numpy()
  10. orig_metrics_res = orig_metrics(torch.from_numpy(pred), torch.from_numpy(label)).numpy()
  11. np.testing.assert_allclose(tinygrad_metrics_res, orig_metrics_res, atol=1e-4)
  12. def test_dice_ce(self):
  13. pred, label = np.random.rand(1, 3, 128, 128, 128).astype(np.float32), np.ones((1, 1, 128, 128, 128)).astype(np.uint8)
  14. self._test_losses(dice_ce_loss, DiceCELoss(True, True, "NCDHW", False), pred, label)
  15. if __name__ == '__main__':
  16. unittest.main()