test_device.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. #!/usr/bin/env python
  2. import unittest
  3. from unittest.mock import patch
  4. import os
  5. from tinygrad.device import Device, Compiler
  6. from tinygrad.helpers import diskcache_get, diskcache_put, getenv
  class TestDevice(unittest.TestCase):
      """Tests for device-name canonicalization via ``Device.canonicalize``."""

      def test_canonicalize(self):
          """Check ``Device.canonicalize`` on None, case variants, and indexed/path device strings.

          Uses ``assertEqual`` rather than bare ``assert`` so the checks still run under
          ``python -O`` and failures report the mismatching values.
          """
          # None resolves to the default device.
          self.assertEqual(Device.canonicalize(None), Device.DEFAULT)
          # (input, expected) pairs: names are upper-cased, an explicit ":0" index is dropped,
          # other indices are kept, and the path part of a DISK device keeps its original case.
          cases = [
              ("CPU", "CPU"),
              ("cpu", "CPU"),
              ("GPU", "GPU"),
              ("GPU:0", "GPU"),
              ("gpu:0", "GPU"),
              ("GPU:1", "GPU:1"),
              ("gpu:1", "GPU:1"),
              ("GPU:2", "GPU:2"),
              ("disk:/dev/shm/test", "DISK:/dev/shm/test"),
          ]
          for name, expected in cases:
              # subTest reports every failing input instead of stopping at the first one.
              with self.subTest(name=name):
                  self.assertEqual(Device.canonicalize(name), expected)
  class MockCompiler(Compiler):
      """Test double for ``Compiler`` whose "compiled" output is just the encoded source text."""

      def __init__(self, key):
          # Hand the cache key straight through to the base Compiler.
          super().__init__(key)

      def compile(self, src) -> bytes:
          """Return ``src`` encoded with ``str.encode``'s default codec as the build artifact."""
          artifact = src.encode()
          return artifact
  class TestCompiler(unittest.TestCase):
      """Tests for ``Compiler.compile_cached`` with the compiler disk cache enabled and disabled."""

      def _compile_under_env(self, key, disable_cache):
          """Compile ``"123"`` with ``MockCompiler(key)`` in an isolated environment.

          Args:
            key: cache key passed to ``MockCompiler``; its ``(key, "123")`` entry is reset first.
            disable_cache: string value for ``DISABLE_COMPILER_CACHE`` (the only env var set).

          Returns:
            Whatever ``compile_cached("123")`` returns.
          """
          diskcache_put(key, "123", None)  # clear cache: overwrite any stored entry with None
          # getenv memoizes its results (it exposes cache_clear). Clear it now so the patched
          # environment is actually read, and clear it again after the test so values read
          # under the patched environment do not leak into later tests in this process.
          getenv.cache_clear()
          self.addCleanup(getenv.cache_clear)
          with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": disable_cache}, clear=True):
              return MockCompiler(key).compile_cached("123")

      def test_compile_cached(self):
          """With the cache enabled, the compiled bytes are returned and stored on disk."""
          self.assertEqual(self._compile_under_env("key", "0"), b"123")
          self.assertEqual(diskcache_get("key", "123"), b"123")

      def test_compile_cached_disabled(self):
          """With DISABLE_COMPILER_CACHE=1, compilation still works but nothing is stored."""
          self.assertEqual(self._compile_under_env("disabled_key", "1"), b"123")
          self.assertIsNone(diskcache_get("disabled_key", "123"))
  if __name__ == "__main__":
      # Executed directly as a script: discover and run every TestCase defined in this module.
      unittest.main()