test_disk_cache.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import unittest
  2. import pickle
  3. from tinygrad.helpers import diskcache_get, diskcache_put, diskcache, diskcache_clear
  4. def remote_get(table,q,k): q.put(diskcache_get(table, k))
  5. def remote_put(table,k,v): diskcache_put(table, k, v)
  6. class DiskCache(unittest.TestCase):
  7. def test_putget(self):
  8. table = "test_putget"
  9. diskcache_put(table, "hello", "world")
  10. self.assertEqual(diskcache_get(table, "hello"), "world")
  11. diskcache_put(table, "hello", "world2")
  12. self.assertEqual(diskcache_get(table, "hello"), "world2")
  13. def test_putcomplex(self):
  14. table = "test_putcomplex"
  15. diskcache_put(table, "k", ("complex", 123, "object"))
  16. ret = diskcache_get(table, "k")
  17. self.assertEqual(ret, ("complex", 123, "object"))
  18. def test_getotherprocess(self):
  19. table = "test_getotherprocess"
  20. from multiprocessing import Process, Queue
  21. diskcache_put(table, "k", "getme")
  22. q = Queue()
  23. p = Process(target=remote_get, args=(table,q,"k"))
  24. p.start()
  25. p.join()
  26. self.assertEqual(q.get(), "getme")
  27. def test_putotherprocess(self):
  28. table = "test_putotherprocess"
  29. from multiprocessing import Process
  30. p = Process(target=remote_put, args=(table,"k", "remote"))
  31. p.start()
  32. p.join()
  33. self.assertEqual(diskcache_get(table, "k"), "remote")
  34. def test_no_table(self):
  35. self.assertIsNone(diskcache_get("faketable", "k"))
  36. def test_ret(self):
  37. table = "test_ret"
  38. self.assertEqual(diskcache_put(table, "key", ("vvs",)), ("vvs",))
  39. def test_non_str_key(self):
  40. table = "test_non_str_key"
  41. diskcache_put(table, 4, 5)
  42. self.assertEqual(diskcache_get(table, 4), 5)
  43. self.assertEqual(diskcache_get(table, "4"), 5)
  44. def test_decorator(self):
  45. calls = 0
  46. @diskcache
  47. def hello(x):
  48. nonlocal calls
  49. calls += 1
  50. return "world"+x
  51. self.assertEqual(hello("bob"), "worldbob")
  52. self.assertEqual(hello("billy"), "worldbilly")
  53. kcalls = calls
  54. self.assertEqual(hello("bob"), "worldbob")
  55. self.assertEqual(hello("billy"), "worldbilly")
  56. self.assertEqual(kcalls, calls)
  57. def test_dict_key(self):
  58. table = "test_dict_key"
  59. fancy_key = {"hello": "world", "goodbye": 7, "good": True, "pkl": pickle.dumps("cat")}
  60. fancy_key2 = {"hello": "world", "goodbye": 8, "good": True, "pkl": pickle.dumps("cat")}
  61. fancy_key3 = {"hello": "world", "goodbye": 8, "good": True, "pkl": pickle.dumps("dog")}
  62. diskcache_put(table, fancy_key, 5)
  63. self.assertEqual(diskcache_get(table, fancy_key), 5)
  64. diskcache_put(table, fancy_key2, 8)
  65. self.assertEqual(diskcache_get(table, fancy_key2), 8)
  66. self.assertEqual(diskcache_get(table, fancy_key), 5)
  67. self.assertEqual(diskcache_get(table, fancy_key3), None)
  68. def test_table_name(self):
  69. table = "test_gfx1010:xnack-"
  70. diskcache_put(table, "key", "test")
  71. self.assertEqual(diskcache_get(table, "key"), "test")
  72. @unittest.skip("disabled by default because this drops cache table")
  73. def test_clear_cache(self):
  74. # clear cache to start
  75. diskcache_clear()
  76. tables = [f"test_clear_cache:{i}" for i in range(3)]
  77. for table in tables:
  78. # check no entries
  79. self.assertIsNone(diskcache_get(table, "k"))
  80. for table in tables:
  81. diskcache_put(table, "k", "test")
  82. # check insertion
  83. self.assertEqual(diskcache_get(table, "k"), "test")
  84. diskcache_clear()
  85. for table in tables:
  86. # check no entries again
  87. self.assertIsNone(diskcache_get(table, "k"))
  88. # calling multiple times is fine
  89. diskcache_clear()
  90. diskcache_clear()
  91. diskcache_clear()
  92. if __name__ == "__main__":
  93. unittest.main()