external_test_example.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
import functools
import unittest

from tinygrad import Device
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, CI
  5. def multidevice_test(fxn):
  6. exclude_devices = getenv("EXCLUDE_DEVICES", "").split(",")
  7. def ret(self):
  8. for device in Device._devices:
  9. if device in ["DISK", "NPY", "FAKE"]: continue
  10. if not CI: print(device)
  11. if device in exclude_devices:
  12. if not CI: print(f"WARNING: {device} test is excluded")
  13. continue
  14. with self.subTest(device=device):
  15. try:
  16. Device[device]
  17. except Exception:
  18. if not CI: print(f"WARNING: {device} test isn't running")
  19. continue
  20. fxn(self, device)
  21. return ret
  22. class TestExample(unittest.TestCase):
  23. @multidevice_test
  24. def test_convert_to_clang(self, device):
  25. a = Tensor([[1,2],[3,4]], device=device)
  26. assert a.numpy().shape == (2,2)
  27. b = a.clang()
  28. assert b.numpy().shape == (2,2)
  29. @multidevice_test
  30. def test_2_plus_3(self, device):
  31. a = Tensor([2], device=device)
  32. b = Tensor([3], device=device)
  33. result = a + b
  34. print(f"{a.numpy()} + {b.numpy()} = {result.numpy()}")
  35. assert result.numpy()[0] == 5.
  36. @multidevice_test
  37. def test_example_readme(self, device):
  38. x = Tensor.eye(3, device=device, requires_grad=True)
  39. y = Tensor([[2.0,0,-2.0]], device=device, requires_grad=True)
  40. z = y.matmul(x).sum()
  41. z.backward()
  42. x.grad.numpy() # dz/dx
  43. y.grad.numpy() # dz/dy
  44. assert x.grad.device == device
  45. assert y.grad.device == device
  46. @multidevice_test
  47. def test_example_matmul(self, device):
  48. try:
  49. Device[device]
  50. except Exception:
  51. print(f"WARNING: {device} test isn't running")
  52. return
  53. x = Tensor.eye(64, device=device, requires_grad=True)
  54. y = Tensor.eye(64, device=device, requires_grad=True)
  55. z = y.matmul(x).sum()
  56. z.backward()
  57. x.grad.numpy() # dz/dx
  58. y.grad.numpy() # dz/dy
  59. assert x.grad.device == device
  60. assert y.grad.device == device
  61. if __name__ == '__main__':
  62. unittest.main()