external_test_hsa_driver.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import ctypes, unittest
  2. from tinygrad.helpers import init_c_struct_t
  3. from tinygrad.device import Device, Buffer, BufferXfer
  4. from tinygrad.dtype import dtypes
  5. from tinygrad.runtime.support.hsa import AQLQueue
  6. from tinygrad.runtime.graph.hsa import VirtAQLQueue, HSAGraph
  7. from tinygrad.engine.realize import ExecItem
  8. def get_hsa_inc_prog(dev, inc=1):
  9. prg = f"""
  10. extern "C" __attribute__((global)) void test_inc(int* data0) {{
  11. data0[0] = (data0[0]+{inc});
  12. }}
  13. """
  14. return dev.runtime("test_inc", dev.compiler.compile(prg))
  15. def get_hsa_buffer_and_kernargs(dev):
  16. test_buf = Buffer(Device.DEFAULT, 1, dtypes.int)
  17. test_buf.copyin(memoryview(bytearray(4))) # zero mem
  18. assert test_buf.as_buffer().cast('I')[0] == 0 # check mem is visible + sync to exec
  19. args_struct_t = init_c_struct_t(tuple([('f0', ctypes.c_void_p)]))
  20. kernargs = dev.alloc_kernargs(8)
  21. args_st = args_struct_t.from_address(kernargs)
  22. args_st.__setattr__('f0', test_buf._buf)
  23. dev.flush_hdp()
  24. return test_buf, kernargs
  25. @unittest.skipUnless(Device.DEFAULT == "HSA", "only run on HSA")
  26. class TestHSADriver(unittest.TestCase):
  27. def test_hsa_simple_enqueue(self):
  28. dev = Device[Device.DEFAULT]
  29. queue = AQLQueue(dev, sz=256)
  30. clprg = get_hsa_inc_prog(dev, inc=1)
  31. test_buf, kernargs = get_hsa_buffer_and_kernargs(dev)
  32. queue.submit_kernel(clprg, [1,1,1], [1,1,1], kernargs)
  33. queue.wait()
  34. assert test_buf.as_buffer().cast('I')[0] == 1, f"{test_buf.as_buffer().cast('I')[0]} != 1, all packets executed?"
  35. del queue
  36. def test_hsa_ring_enqueue(self):
  37. dev = Device[Device.DEFAULT]
  38. queue_size = 256
  39. exec_cnt = int(queue_size * 1.5)
  40. queue = AQLQueue(dev, sz=queue_size)
  41. clprg_inc1 = get_hsa_inc_prog(dev, inc=1)
  42. clprg_inc2 = get_hsa_inc_prog(dev, inc=2)
  43. test_buf, kernargs = get_hsa_buffer_and_kernargs(dev)
  44. for _ in range(exec_cnt):
  45. queue.submit_kernel(clprg_inc1, [1,1,1], [1,1,1], kernargs)
  46. for _ in range(exec_cnt):
  47. queue.submit_kernel(clprg_inc2, [1,1,1], [1,1,1], kernargs)
  48. queue.wait()
  49. expected = exec_cnt + exec_cnt * 2
  50. assert test_buf.as_buffer().cast('I')[0] == expected, f"{test_buf.as_buffer().cast('I')[0]} != {expected}, all packets executed?"
  51. del queue
  52. def test_hsa_blit_enqueue(self):
  53. dev = Device[Device.DEFAULT]
  54. queue_size = 256
  55. exec_cnt = 178
  56. queue = AQLQueue(dev, sz=queue_size)
  57. test_buf, kernargs = get_hsa_buffer_and_kernargs(dev)
  58. # Using VirtAQLQueue to blit them
  59. virt_queue_packets_cnt = 31
  60. virt_queue = VirtAQLQueue(dev, sz=virt_queue_packets_cnt)
  61. clprogs = []
  62. sum_per_blit = 0
  63. for i in range(virt_queue_packets_cnt):
  64. sum_per_blit += i+1
  65. clprogs.append(get_hsa_inc_prog(dev, inc=i+1))
  66. for i in range(virt_queue_packets_cnt):
  67. virt_queue.submit_kernel(clprogs[i], [1,1,1], [1,1,1], kernargs)
  68. for _ in range(exec_cnt):
  69. queue.blit_packets(virt_queue.queue_base, virt_queue.packets_count)
  70. queue.wait()
  71. expected = exec_cnt * sum_per_blit
  72. assert test_buf.as_buffer().cast('I')[0] == expected, f"{test_buf.as_buffer().cast('I')[0]} != {expected}, all packets executed?"
  73. del queue, clprogs
  74. def test_hsa_copies_sync(self):
  75. d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
  76. test_buf0 = Buffer(d0, 1, dtypes.int)
  77. test_buf1 = Buffer(d0, 1, dtypes.int)
  78. test_buf2 = Buffer(d1, 1, dtypes.int)
  79. test_buf0.copyin(memoryview(bytearray(1*4)))
  80. test_buf1.copyin(memoryview(bytearray(1*4)))
  81. test_buf2.copyin(memoryview(bytearray(1*4)))
  82. jit_cache = [ExecItem(BufferXfer(), [test_buf0, test_buf2]), ExecItem(BufferXfer(), [test_buf2, test_buf1])]
  83. graph = HSAGraph(jit_cache, [], {})
  84. for i in range(10000):
  85. test_buf0.copyin(memoryview(bytearray(1*4)))
  86. test_buf2.copyin(memoryview(bytearray(int.to_bytes(4, length=1*4, byteorder='little'))))
  87. graph([], {})
  88. assert test_buf0.as_buffer().cast('I')[0] == 4
  89. assert test_buf2.as_buffer().cast('I')[0] == 0
  90. if __name__ == '__main__':
  91. unittest.main()