| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108 |
- import unittest
- import pickle
- from tinygrad.helpers import diskcache_get, diskcache_put, diskcache, diskcache_clear
def remote_get(table, q, k):
  """Look up key `k` in cache table `table` and put the result on queue `q`.

  Used as a `Process` target so a child process can report what it read.
  """
  fetched = diskcache_get(table, k)
  q.put(fetched)
def remote_put(table, k, v):
  """Store value `v` under key `k` in cache table `table`.

  Used as a `Process` target so a child process can perform the write.
  """
  diskcache_put(table, k, v)
class DiskCache(unittest.TestCase):
  """Tests for the diskcache_get / diskcache_put / diskcache / diskcache_clear helpers.

  Most tests write to a table named after the test itself, so they do not share keys.
  """

  def test_putget(self):
    # a second put under the same key replaces the value returned by get
    table = "test_putget"
    diskcache_put(table, "hello", "world")
    self.assertEqual(diskcache_get(table, "hello"), "world")
    diskcache_put(table, "hello", "world2")
    self.assertEqual(diskcache_get(table, "hello"), "world2")

  def test_putcomplex(self):
    # a tuple value comes back equal to what was stored
    table = "test_putcomplex"
    diskcache_put(table, "k", ("complex", 123, "object"))
    ret = diskcache_get(table, "k")
    self.assertEqual(ret, ("complex", 123, "object"))

  def test_getotherprocess(self):
    # a value written by this process is readable from a child process
    table = "test_getotherprocess"
    from multiprocessing import Process, Queue
    diskcache_put(table, "k", "getme")
    q = Queue()
    p = Process(target=remote_get, args=(table, q, "k"))
    p.start()
    try:
      # Drain the queue BEFORE joining: joining a process that still has items
      # buffered on a Queue can deadlock. The timeout turns a crashed child into
      # a test error (queue.Empty) instead of an indefinite hang.
      ret = q.get(timeout=60)
    finally:
      p.join()
    self.assertEqual(p.exitcode, 0)
    self.assertEqual(ret, "getme")

  def test_putotherprocess(self):
    # a value written by a child process is readable from this process
    table = "test_putotherprocess"
    from multiprocessing import Process
    p = Process(target=remote_put, args=(table, "k", "remote"))
    p.start()
    p.join()
    # surface a child failure directly rather than as a confusing cache miss
    self.assertEqual(p.exitcode, 0)
    self.assertEqual(diskcache_get(table, "k"), "remote")

  def test_no_table(self):
    # reading from a table that was never written returns None
    self.assertIsNone(diskcache_get("faketable", "k"))

  def test_ret(self):
    # diskcache_put returns the value it was given
    table = "test_ret"
    self.assertEqual(diskcache_put(table, "key", ("vvs",)), ("vvs",))

  def test_non_str_key(self):
    # an int key can be read back both as the int and as its string form
    table = "test_non_str_key"
    diskcache_put(table, 4, 5)
    self.assertEqual(diskcache_get(table, 4), 5)
    self.assertEqual(diskcache_get(table, "4"), 5)

  def test_decorator(self):
    # repeating calls with already-seen arguments must not run the function body again
    calls = 0

    @diskcache
    def hello(x):
      nonlocal calls
      calls += 1
      return "world"+x

    self.assertEqual(hello("bob"), "worldbob")
    self.assertEqual(hello("billy"), "worldbilly")
    # snapshot rather than compare to a fixed count
    # (presumably entries may already exist from an earlier run -- TODO confirm)
    kcalls = calls
    self.assertEqual(hello("bob"), "worldbob")
    self.assertEqual(hello("billy"), "worldbilly")
    self.assertEqual(kcalls, calls)

  def test_dict_key(self):
    # dict keys that differ in a single field (int or bytes) are distinct cache keys
    table = "test_dict_key"
    fancy_key = {"hello": "world", "goodbye": 7, "good": True, "pkl": pickle.dumps("cat")}
    fancy_key2 = {"hello": "world", "goodbye": 8, "good": True, "pkl": pickle.dumps("cat")}
    fancy_key3 = {"hello": "world", "goodbye": 8, "good": True, "pkl": pickle.dumps("dog")}
    diskcache_put(table, fancy_key, 5)
    self.assertEqual(diskcache_get(table, fancy_key), 5)
    diskcache_put(table, fancy_key2, 8)
    self.assertEqual(diskcache_get(table, fancy_key2), 8)
    self.assertEqual(diskcache_get(table, fancy_key), 5)
    # fancy_key3 was never stored
    self.assertIsNone(diskcache_get(table, fancy_key3))

  def test_table_name(self):
    # a table name containing ':' and '-' still round-trips
    table = "test_gfx1010:xnack-"
    diskcache_put(table, "key", "test")
    self.assertEqual(diskcache_get(table, "key"), "test")

  @unittest.skip("disabled by default because this drops cache table")
  def test_clear_cache(self):
    # clear cache to start
    diskcache_clear()
    tables = [f"test_clear_cache:{i}" for i in range(3)]
    for table in tables:
      # check no entries
      self.assertIsNone(diskcache_get(table, "k"))
    for table in tables:
      diskcache_put(table, "k", "test")
      # check insertion
      self.assertEqual(diskcache_get(table, "k"), "test")
    diskcache_clear()
    for table in tables:
      # check no entries again
      self.assertIsNone(diskcache_get(table, "k"))
    # calling multiple times is fine
    diskcache_clear()
    diskcache_clear()
    diskcache_clear()
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
  unittest.main()
|