test_graph.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. import numpy as np
  2. import unittest, ctypes
  3. from tinygrad.device import Device, Buffer
  4. from tinygrad.tensor import Tensor, _to_np_dtype
  5. from tinygrad.engine.schedule import create_schedule
  6. from tinygrad.helpers import Context, CI, dedup, from_mv
  7. from tinygrad.dtype import dtypes
  8. from tinygrad.engine.realize import ExecItem, BufferXfer, get_runner, CompiledRunner
  9. np.random.seed(1337)
  10. Tensor.manual_seed(1337)
  11. BUF_SIZE = 4096 if CI else 4096 * 128
  12. RUN_CNT = 4 if CI else 32
  13. cached_prgs = {}
  14. def helper_exec_op(device, outbuf, inbufs):
  15. if (device, len(inbufs)) not in cached_prgs:
  16. with Context(DEBUG=0):
  17. fst = [Tensor.randn(BUF_SIZE, dtype=dtypes.int).realize() for i in range(len(inbufs))]
  18. s = fst[0]
  19. for i in range(1, len(inbufs)): s = s.xor(fst[i])
  20. si = create_schedule([s.lazydata])[-1]
  21. prg = get_runner(device, si.ast)
  22. cached_prgs[(device, len(inbufs))] = prg
  23. return ExecItem(cached_prgs[(device, len(inbufs))], [outbuf] + inbufs)
  24. def helper_copy_op(device, dest, src):
  25. prg = BufferXfer(dest.nbytes, device, src.device)
  26. return ExecItem(prg, [dest, src])
  27. def helper_alloc_rawbuffer(device, fill=False):
  28. rawbuf = Buffer(device, BUF_SIZE, dtypes.int).ensure_allocated()
  29. if fill:
  30. with Context(DEBUG=0):
  31. data = np.random.randint(-10000, 10000, size=rawbuf.size, dtype=_to_np_dtype(rawbuf.dtype))
  32. rawbuf.copyin(Tensor(data).realize().lazydata.realized.as_buffer())
  33. return rawbuf
  34. def helper_run_jit(jis, bufs, out_buffers):
  35. for rawbuf in out_buffers:
  36. mv = memoryview(bytearray(rawbuf.size * rawbuf.dtype.itemsize))
  37. ctypes.memset(from_mv(mv), 0, len(mv))
  38. rawbuf.copyin(mv)
  39. for ei in jis: ei.run({}, jit=True)
  40. return [rawbuf.as_buffer() for rawbuf in bufs]
  41. def helper_test_graphs(graph_impl, graphs, runs=RUN_CNT):
  42. reg_ji = []
  43. bufs = []
  44. out_buffers = set()
  45. for graph in graphs:
  46. for ji in graph:
  47. writable_buffers = ji.prg.p.outcount if isinstance(ji.prg, CompiledRunner) else 1
  48. out_buffers.update(ji.bufs[:writable_buffers])
  49. bufs += ji.bufs
  50. reg_ji.append(ji)
  51. bufs = dedup(bufs)
  52. ground_thruth_bufs = helper_run_jit(reg_ji, bufs, out_buffers)
  53. ground_truth_np = [np.frombuffer(x, _to_np_dtype(bufs[i].dtype)) for i,x in enumerate(ground_thruth_bufs)]
  54. # Build graphs
  55. gr_ji = [ExecItem(graph_impl(graph, [], {}), []) for graph in graphs]
  56. for _ in range(runs):
  57. test_bufs = helper_run_jit(gr_ji, bufs, out_buffers)
  58. test_bufs_np = [np.frombuffer(x, _to_np_dtype(bufs[i].dtype)) for i,x in enumerate(test_bufs)]
  59. for i in range(len(ground_thruth_bufs)): np.testing.assert_equal(ground_truth_np[i], test_bufs_np[i])
  60. @unittest.skipUnless(Device[Device.DEFAULT].graph is not None, "graph support required")
  61. @unittest.skipIf(CI and Device.DEFAULT=="METAL", "no ICB in CI, creation of graph fails")
  62. class TestGraph(unittest.TestCase):
  63. def test_order_2_writes_to_same_buf(self):
  64. d0 = Device.DEFAULT
  65. b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(5)]
  66. graphs = [
  67. [helper_exec_op(d0, b0[0], [b0[1], b0[2]]), helper_exec_op(d0, b0[0], [b0[3], b0[4]])]
  68. ]
  69. helper_test_graphs(Device[d0].graph, graphs)
  70. def test_order_read_write_same_buf(self):
  71. d0 = Device.DEFAULT
  72. b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(5)]
  73. graphs = [
  74. [helper_exec_op(d0, b0[0], [b0[1], b0[2]]), helper_exec_op(d0, b0[1], [b0[3], b0[4]])]
  75. ]
  76. helper_test_graphs(Device[d0].graph, graphs)
  77. def test_order_write_read_same_buf(self):
  78. d0 = Device.DEFAULT
  79. b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(5)]
  80. graphs = [
  81. [helper_exec_op(d0, b0[0], [b0[1], b0[2]]), helper_exec_op(d0, b0[1], [b0[0], b0[4]])]
  82. ]
  83. helper_test_graphs(Device[d0].graph, graphs)
  84. @unittest.skipUnless(Device.DEFAULT in {"CUDA", "NV", "AMD"}, "mutidevice graph required")
  85. def test_order_copy_writed(self):
  86. d0 = Device.DEFAULT
  87. b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(4)]
  88. graphs = [
  89. [helper_exec_op(d0, b0[0], [b0[1], b0[2]]), helper_copy_op(d0, b0[3], b0[0])]
  90. ]
  91. helper_test_graphs(Device[d0].graph, graphs)
  92. @unittest.skipUnless(Device.DEFAULT in {"CUDA", "NV", "AMD"}, "mutidevice graph required")
  93. def test_order_copy_then_read(self):
  94. d0 = Device.DEFAULT
  95. b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(4)]
  96. graphs = [
  97. [helper_copy_op(d0, b0[1], b0[0]), helper_exec_op(d0, b0[3], [b0[1], b0[2]])]
  98. ]
  99. helper_test_graphs(Device[d0].graph, graphs)
  100. def test_read_write_several_graphs(self):
  101. d0 = Device.DEFAULT
  102. b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(8)]
  103. graphs = [
  104. [helper_exec_op(d0, b0[3], [b0[1], b0[2]])],
  105. [helper_exec_op(d0, b0[4], [b0[1], b0[3]])],
  106. [helper_exec_op(d0, b0[5], [b0[4], b0[2]])]
  107. ]
  108. helper_test_graphs(Device[d0].graph, graphs)
  109. graphs = [
  110. [helper_exec_op(d0, b0[3], [b0[1], b0[2]]), helper_exec_op(d0, b0[4], [b0[1], b0[2]]), helper_exec_op(d0, b0[5], [b0[1], b0[2]])],
  111. [helper_exec_op(d0, b0[2], [b0[6], b0[7]])]
  112. ]
  113. helper_test_graphs(Device[d0].graph, graphs)
  114. @unittest.skipUnless(Device.DEFAULT in {"CUDA", "NV", "AMD"}, "mutidevice graph required")
  115. def test_copies_2_devs(self):
  116. d0, d1 = Device.DEFAULT, f"{Device.DEFAULT}:1"
  117. b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(3)]
  118. b1 = [helper_alloc_rawbuffer(d1, fill=True) for _ in range(1)]
  119. graphs = [
  120. [helper_copy_op(d0, b1[0], b0[0]), helper_exec_op(d0, b0[2], [b0[0], b0[1]])]
  121. ]
  122. helper_test_graphs(Device[d0].graph, graphs)
  123. @unittest.skipUnless(Device.DEFAULT in {"CUDA", "NV", "AMD"}, "mutidevice graph required")
  124. def test_copies_after_graph_global(self):
  125. d0, d1, d2, d3 = Device.DEFAULT, f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3"
  126. b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(8)]
  127. b1 = [helper_alloc_rawbuffer(d1, fill=True) for _ in range(6)]
  128. b2 = [helper_alloc_rawbuffer(d2, fill=True) for _ in range(6)]
  129. b3 = [helper_alloc_rawbuffer(d3, fill=True) for _ in range(6)]
  130. graphs = [
  131. [helper_exec_op(d0, b0[2], [b0[0], b0[1]]), helper_exec_op(d0, b0[3], [b0[0], b0[2]]), helper_exec_op(d0, b0[4], [b0[3], b0[2]]),
  132. helper_exec_op(d0, b0[5], [b0[0], b0[2]]), helper_exec_op(d0, b0[6], [b0[1], b0[2]]), helper_exec_op(d0, b0[7], [b0[0], b0[2]])],
  133. [helper_copy_op(d1, b0[2], b1[0])],
  134. [helper_exec_op(d0, b0[2], [b0[0], b0[1]]), helper_exec_op(d0, b0[3], [b0[0], b0[2]]), helper_exec_op(d0, b0[4], [b0[3], b0[2]]),
  135. helper_exec_op(d0, b0[5], [b0[0], b0[2]]), helper_exec_op(d0, b0[6], [b0[1], b0[2]]), helper_exec_op(d0, b0[7], [b0[0], b0[2]])],
  136. [helper_copy_op(d3, b0[2], b3[0])],
  137. ]
  138. helper_test_graphs(Device[d0].graph, graphs)
  139. graphs = [
  140. [helper_exec_op(d0, b0[2], [b0[0], b0[1]]), helper_exec_op(d0, b0[3], [b0[0], b0[2]]), helper_exec_op(d0, b0[4], [b0[3], b0[2]]),
  141. helper_exec_op(d0, b0[5], [b0[0], b0[2]]), helper_copy_op(d0, b2[0], b0[2]), helper_copy_op(d0, b2[1], b0[5]),
  142. helper_exec_op(d0, b0[7], [b0[0], b0[2]])],
  143. [helper_copy_op(d1, b0[2], b1[0])],
  144. [helper_exec_op(d0, b0[2], [b0[0], b0[1]])],
  145. [helper_copy_op(d3, b0[2], b3[0])],
  146. ]
  147. helper_test_graphs(Device[d0].graph, graphs)
  148. graphs = [
  149. [helper_exec_op(d0, b0[2], [b0[0], b0[1]]), helper_exec_op(d0, b0[3], [b0[0], b0[2]]), helper_exec_op(d0, b0[4], [b0[3], b0[2]]),
  150. helper_exec_op(d0, b0[5], [b0[0], b0[2]]), helper_copy_op(d0, b2[0], b0[2]), helper_copy_op(d0, b2[1], b0[5]),
  151. helper_exec_op(d0, b0[7], [b0[0], b0[2]])],
  152. [helper_copy_op(d1, b0[5], b1[0])],
  153. [helper_copy_op(d3, b0[5], b3[0])],
  154. ]
  155. helper_test_graphs(Device[d0].graph, graphs)
  156. graphs = [
  157. [helper_copy_op(d1, b0[5], b1[0])],
  158. [helper_copy_op(d3, b0[5], b3[0])],
  159. ]
  160. helper_test_graphs(Device[d0].graph, graphs)
  161. @unittest.skipUnless(Device.DEFAULT in {"CUDA", "NV", "AMD"}, "mutidevice graph required")
  162. def test_graph_after_copies_devs(self):
  163. d0, d1, d2, d3 = Device.DEFAULT, f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3"
  164. b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(8)]
  165. b1 = [helper_alloc_rawbuffer(d1, fill=True) for _ in range(1)]
  166. b2 = [helper_alloc_rawbuffer(d2, fill=True) for _ in range(2)]
  167. b3 = [helper_alloc_rawbuffer(d3, fill=True) for _ in range(2)]
  168. graphs = [
  169. [helper_copy_op(d1, b0[0], b1[0])],
  170. [helper_copy_op(d2, b0[1], b2[0]), helper_copy_op(d3, b0[2], b3[0])],
  171. [helper_exec_op(d0, b0[3], [b0[0], b0[2]]), helper_exec_op(d0, b0[4], [b0[3], b0[2]]),
  172. helper_exec_op(d0, b0[5], [b0[0], b0[2]])],
  173. ]
  174. helper_test_graphs(Device[d0].graph, graphs)
  175. graphs = [
  176. [helper_copy_op(d1, b0[0], b1[0])],
  177. [helper_exec_op(d0, b0[2], [b0[0], b0[1]])],
  178. [helper_copy_op(d2, b0[1], b2[0]), helper_copy_op(d3, b0[2], b3[0])],
  179. [helper_exec_op(d0, b0[3], [b0[0], b0[2]]), helper_exec_op(d0, b0[4], [b0[3], b0[2]]),
  180. helper_exec_op(d0, b0[5], [b0[0], b0[2]])],
  181. ]
  182. helper_test_graphs(Device[d0].graph, graphs)
  183. if __name__ == '__main__':
  184. unittest.main()