| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116 |
- import ctypes, unittest
- from tinygrad.helpers import init_c_struct_t
- from tinygrad.device import Device, Buffer, BufferXfer
- from tinygrad.dtype import dtypes
- from tinygrad.runtime.support.hsa import AQLQueue
- from tinygrad.runtime.graph.hsa import VirtAQLQueue, HSAGraph
- from tinygrad.engine.realize import ExecItem
- def get_hsa_inc_prog(dev, inc=1):
- prg = f"""
- extern "C" __attribute__((global)) void test_inc(int* data0) {{
- data0[0] = (data0[0]+{inc});
- }}
- """
- return dev.runtime("test_inc", dev.compiler.compile(prg))
- def get_hsa_buffer_and_kernargs(dev):
- test_buf = Buffer(Device.DEFAULT, 1, dtypes.int)
- test_buf.copyin(memoryview(bytearray(4))) # zero mem
- assert test_buf.as_buffer().cast('I')[0] == 0 # check mem is visible + sync to exec
- args_struct_t = init_c_struct_t(tuple([('f0', ctypes.c_void_p)]))
- kernargs = dev.alloc_kernargs(8)
- args_st = args_struct_t.from_address(kernargs)
- args_st.__setattr__('f0', test_buf._buf)
- dev.flush_hdp()
- return test_buf, kernargs
- @unittest.skipUnless(Device.DEFAULT == "HSA", "only run on HSA")
- class TestHSADriver(unittest.TestCase):
- def test_hsa_simple_enqueue(self):
- dev = Device[Device.DEFAULT]
- queue = AQLQueue(dev, sz=256)
- clprg = get_hsa_inc_prog(dev, inc=1)
- test_buf, kernargs = get_hsa_buffer_and_kernargs(dev)
- queue.submit_kernel(clprg, [1,1,1], [1,1,1], kernargs)
- queue.wait()
- assert test_buf.as_buffer().cast('I')[0] == 1, f"{test_buf.as_buffer().cast('I')[0]} != 1, all packets executed?"
- del queue
- def test_hsa_ring_enqueue(self):
- dev = Device[Device.DEFAULT]
- queue_size = 256
- exec_cnt = int(queue_size * 1.5)
- queue = AQLQueue(dev, sz=queue_size)
- clprg_inc1 = get_hsa_inc_prog(dev, inc=1)
- clprg_inc2 = get_hsa_inc_prog(dev, inc=2)
- test_buf, kernargs = get_hsa_buffer_and_kernargs(dev)
- for _ in range(exec_cnt):
- queue.submit_kernel(clprg_inc1, [1,1,1], [1,1,1], kernargs)
- for _ in range(exec_cnt):
- queue.submit_kernel(clprg_inc2, [1,1,1], [1,1,1], kernargs)
- queue.wait()
- expected = exec_cnt + exec_cnt * 2
- assert test_buf.as_buffer().cast('I')[0] == expected, f"{test_buf.as_buffer().cast('I')[0]} != {expected}, all packets executed?"
- del queue
- def test_hsa_blit_enqueue(self):
- dev = Device[Device.DEFAULT]
- queue_size = 256
- exec_cnt = 178
- queue = AQLQueue(dev, sz=queue_size)
- test_buf, kernargs = get_hsa_buffer_and_kernargs(dev)
- # Using VirtAQLQueue to blit them
- virt_queue_packets_cnt = 31
- virt_queue = VirtAQLQueue(dev, sz=virt_queue_packets_cnt)
- clprogs = []
- sum_per_blit = 0
- for i in range(virt_queue_packets_cnt):
- sum_per_blit += i+1
- clprogs.append(get_hsa_inc_prog(dev, inc=i+1))
- for i in range(virt_queue_packets_cnt):
- virt_queue.submit_kernel(clprogs[i], [1,1,1], [1,1,1], kernargs)
- for _ in range(exec_cnt):
- queue.blit_packets(virt_queue.queue_base, virt_queue.packets_count)
- queue.wait()
- expected = exec_cnt * sum_per_blit
- assert test_buf.as_buffer().cast('I')[0] == expected, f"{test_buf.as_buffer().cast('I')[0]} != {expected}, all packets executed?"
- del queue, clprogs
- def test_hsa_copies_sync(self):
- d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
- test_buf0 = Buffer(d0, 1, dtypes.int)
- test_buf1 = Buffer(d0, 1, dtypes.int)
- test_buf2 = Buffer(d1, 1, dtypes.int)
- test_buf0.copyin(memoryview(bytearray(1*4)))
- test_buf1.copyin(memoryview(bytearray(1*4)))
- test_buf2.copyin(memoryview(bytearray(1*4)))
- jit_cache = [ExecItem(BufferXfer(), [test_buf0, test_buf2]), ExecItem(BufferXfer(), [test_buf2, test_buf1])]
- graph = HSAGraph(jit_cache, [], {})
- for i in range(10000):
- test_buf0.copyin(memoryview(bytearray(1*4)))
- test_buf2.copyin(memoryview(bytearray(int.to_bytes(4, length=1*4, byteorder='little'))))
- graph([], {})
- assert test_buf0.as_buffer().cast('I')[0] == 4
- assert test_buf2.as_buffer().cast('I')[0] == 0
- if __name__ == '__main__':
- unittest.main()
|