test_pickle.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import unittest, pickle
  2. import numpy as np
  3. from tinygrad import Tensor, TinyJit, Variable
  4. from tinygrad.engine.schedule import create_schedule
  5. class TestPickle(unittest.TestCase):
  6. def test_pickle_realized_tensor(self):
  7. t = Tensor.rand(10, 10).realize()
  8. st = pickle.dumps(t)
  9. t2:Tensor = pickle.loads(st)
  10. np.testing.assert_equal(t.numpy(), t2.numpy())
  11. def test_pickle_unrealized_tensor(self):
  12. t = Tensor.ones(10, 10)
  13. st = pickle.dumps(t)
  14. t2:Tensor = pickle.loads(st)
  15. np.testing.assert_equal(t.numpy(), t2.numpy())
  16. def test_pickle_variable(self):
  17. v = Variable("i", 1, 20).bind(10)
  18. t1 = Tensor.ones(10, v).contiguous()
  19. t2 = Tensor.ones(10, v).contiguous()
  20. ret = (t1+t2).sum(1)
  21. st = pickle.dumps(ret)
  22. del ret
  23. vt2 = pickle.loads(st)
  24. np.testing.assert_equal(vt2.numpy(), 20)
  25. def test_pickle_buffer_view(self):
  26. t = Tensor.arange(10, device="CLANG").contiguous().realize()
  27. vt = t[3:5].contiguous().realize()
  28. assert hasattr(vt.lazydata.buffer, 'base')
  29. ref_value = vt.tolist()
  30. st = pickle.dumps(vt)
  31. del t, vt
  32. vt2 = pickle.loads(st)
  33. assert hasattr(vt2.lazydata.buffer, 'base')
  34. assert ref_value == vt2.tolist()
  35. def test_pickle_numpy(self):
  36. t = Tensor(np.array([1,2,3,4.]))
  37. st = pickle.dumps(t)
  38. t2:Tensor = pickle.loads(st)
  39. np.testing.assert_equal(t.numpy(), t2.numpy())
  40. def test_pickle_jit(self):
  41. @TinyJit
  42. def add(a, b): return a+b+1
  43. for _ in range(3): add(Tensor.rand(10, 10), Tensor.rand(10, 10))
  44. del add.fxn # pickling the JIT requires the function to be deleted
  45. st = pickle.dumps(add)
  46. del add
  47. add_fxn = pickle.loads(st)
  48. x = Tensor.ones(10, 10).contiguous().realize()
  49. y = Tensor.ones(10, 10).contiguous().realize()
  50. print("post jit")
  51. out = add_fxn(x, y)
  52. np.testing.assert_equal(out.numpy(), 3)
  53. def test_pickle_schedule(self):
  54. a = Tensor([1,2])
  55. out = a + 2
  56. sched = create_schedule([out.lazydata])
  57. pk = pickle.dumps(sched)
  58. sched_pk = pickle.loads(pk)
  59. assert sched_pk[-1].ast == sched[-1].ast
  60. if __name__ == '__main__':
  61. unittest.main()