device.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555
  1. from __future__ import annotations
  2. import multiprocessing
  3. from dataclasses import dataclass
  4. from collections import defaultdict
  5. from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type
  6. import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array
  7. from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
  8. from tinygrad.dtype import DType, ImageDType
  9. from tinygrad.renderer import Renderer
  10. # **************** Device ****************
class _Device:
  """Singleton registry of runtime backends: canonicalizes device names and lazily opens `Compiled` instances."""
  def __init__(self) -> None: self._devices: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")] # noqa: E501
  @functools.lru_cache(maxsize=None)  # this class is a singleton, pylint: disable=method-cache-max-size-none
  def _canonicalize(self, device:str) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") # noqa: E501
  # NOTE: you can't cache canonicalize in case Device.DEFAULT changes
  def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT
  def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
  @functools.lru_cache(maxsize=None)  # this class is a singleton, pylint: disable=method-cache-max-size-none
  def __get_canonicalized_item(self, ix:str) -> Compiled:
    # most devices may only be opened from the parent process; DISK and NPY are exempt
    assert ((cpn:=multiprocessing.current_process().name) == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], \
      f"can only open device {ix} from parent, not {cpn}"
    x = ix.split(":")[0].upper()
    # find the class named "<X>Device" in tinygrad.runtime.ops_<x> and instantiate it with the full device string
    ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
    if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
    return ret
  @functools.cached_property
  def DEFAULT(self) -> str:
    # an explicitly-set env var (e.g. CUDA=1) picks the default device
    device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._devices, None) # type: ignore
    if device_from_env: return device_from_env
    # otherwise probe backends in preference order and take the first that opens successfully
    for device in ["METAL", "AMD", "NV", "CUDA", "GPU", "CLANG", "LLVM"]:
      try:
        if self[device]:
          os.environ[device] = "1" # we set this in environment for spawned children
          return device
      except Exception: pass
    raise RuntimeError("no usable devices")
Device = _Device()
  38. # **************** Buffer + Allocators ****************
@dataclass(frozen=True, eq=True)  # frozen+eq so options can be used as dict/cache keys by allocators
class BufferOptions:
  # allocate as an image of this dtype rather than a flat buffer
  image: Optional[ImageDType] = None
  # NOTE(review): the flags below are consumed by backend allocators; exact semantics are backend-specific
  uncached: bool = False
  cpu_access: bool = False
  host: bool = False
  # opt this allocation out of LRUAllocator caching on free
  nolru: bool = False
  46. class Buffer:
  47. def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
  48. initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
  49. assert isinstance(dtype, DType)
  50. if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
  51. self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
  52. if base is None:
  53. assert offset == 0, "base buffers can't have offset"
  54. self._base = None
  55. self._lb_refcount = lb_refcount
  56. if opaque is not None: self.allocate(opaque)
  57. if initial_value is not None:
  58. self.allocate()
  59. self.copyin(memoryview(initial_value))
  60. else:
  61. assert base._base is None, "base can't have a base"
  62. assert device == base.device, "base must have the same device"
  63. self._base = base
  64. if preallocate: self.allocate()
  65. @property
  66. def base(self) -> Buffer: return self._base if self._base is not None else self
  67. @property
  68. def lb_refcount(self): return self.base._lb_refcount
  69. def ref(self, cnt): self.base._lb_refcount += cnt
  70. def is_allocated(self) -> bool: return hasattr(self, '_buf')
  71. def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
  72. def allocate(self, opaque=None) -> Buffer:
  73. assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
  74. self.allocator = Device[self.device].allocator
  75. if self._base is not None:
  76. self._base.ensure_allocated()
  77. assert hasattr(self.allocator, "offset"), "offset function required for view"
  78. self._buf: Any = self.allocator.offset(self.base._buf, self.nbytes, self.offset)
  79. else:
  80. self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
  81. if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
  82. return self
  83. def __reduce__(self):
  84. buf = None
  85. if self._base is not None:
  86. return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
  87. if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
  88. if self.is_allocated():
  89. buf = bytearray(self.nbytes)
  90. self.copyout(memoryview(buf))
  91. return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
  92. @property
  93. def nbytes(self): return self.size*self.dtype.itemsize
  94. def __del__(self):
  95. if not hasattr(self, '_buf'): return
  96. if self._base is None:
  97. if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
  98. self.allocator.free(self._buf, self.nbytes, self.options)
  99. def __repr__(self):
  100. return f"<buf real:{hasattr(self, '_buf')} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
  101. (f" offset:{self.offset}" if hasattr(self, "base") else "") + \
  102. (">" if self.options is None else f" {self.options=}>")
  103. def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
  104. # zero copy with as_buffer (disabled by default due to use after free)
  105. if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
  106. assert not force_zero_copy, "force zero copy was passed, but copy is required"
  107. return self.copyout(memoryview(bytearray(self.nbytes)))
  108. def copyin(self, mv:memoryview):
  109. mv = flat_mv(mv)
  110. assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
  111. assert self.is_allocated(), "can't copyin to unallocated buffer"
  112. self.allocator.copyin(self._buf, mv)
  113. return self
  114. def copyout(self, mv:memoryview) -> memoryview:
  115. mv = flat_mv(mv)
  116. assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
  117. assert self.is_allocated(), "can't copyout unallocated buffer"
  118. self.allocator.copyout(mv, self._buf)
  119. return mv
  120. def view(self, size:int, dtype:DType, offset:int) -> Buffer:
  121. assert offset < self.nbytes, "offset must be less than nbytes"
  122. if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset)
  123. return Buffer(self.device, size, dtype, base=self, offset=offset)
  124. # TODO: size, dest, src are the same type. can we enforce this?
  125. class Allocator:
  126. def alloc(self, size:int, options:Optional[BufferOptions]=None):
  127. assert not isinstance(size, int) or size > 0, f"alloc size must be positve, getting {size}"
  128. return self._alloc(size, options if options is not None else BufferOptions())
  129. def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc")
  130. def free(self, opaque, size:int, options:Optional[BufferOptions]=None):
  131. self._free(opaque, options if options is not None else BufferOptions())
  132. def _free(self, opaque, options:BufferOptions): pass # if opaque is a Python object, you don't need a free
  133. def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
  134. def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
  135. class LRUAllocator(Allocator): # pylint: disable=abstract-method
  136. """
  137. The LRU Allocator is responsible for caching buffers.
  138. It ensures that buffers are not freed until it is absolutely necessary, optimizing performance.
  139. """
  140. def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)
  141. def alloc(self, size:int, options:Optional[BufferOptions]=None):
  142. if len(c := self.cache[(size, options)]): return c.pop()
  143. try: return super().alloc(size, options)
  144. except (RuntimeError, MemoryError):
  145. self.free_cache()
  146. return super().alloc(size, options)
  147. def free_cache(self):
  148. for (sz,options),opaques in self.cache.items():
  149. for opaque in opaques: super().free(opaque, sz, options)
  150. opaques.clear()
  151. def free(self, opaque:Any, size:int, options:Optional[BufferOptions]=None):
  152. if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
  153. else: super().free(opaque, size, options)
class _MallocAllocator(LRUAllocator):
  """CPU allocator backed by ctypes byte arrays, with LRU caching from the base class."""
  def _alloc(self, size:int, options:BufferOptions): return (ctypes.c_uint8 * size)()
  def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
  def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
  def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
  # views into a malloc'd buffer are just pointers into the same memory
  def offset(self, buf, size:int, offset:int): return from_mv(self.as_buffer(buf)[offset:offset+size])
MallocAllocator = _MallocAllocator()
  161. # **************** for Compiled Devices ****************
class CompileError(Exception): pass  # exception type for compilation failures
  163. class Compiler:
  164. def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
  165. def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
  166. def compile_cached(self, src:str) -> bytes:
  167. if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
  168. assert not getenv("ASSERT_COMPILE"), f"tried to compile with ASSERT_COMPILE set\n{src}"
  169. lib = self.compile(src)
  170. if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
  171. return lib
class Compiled:
  """Base class for a compiled device backend, bundling its allocator, renderer, compiler, runtime and graph."""
  def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
    # renderer/compiler fall back to no-op defaults when the backend doesn't provide them
    self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
    self.renderer = renderer or Renderer()
  def synchronize(self):
    """
    Synchronize all pending operations on the device.

    This method ensures that all previously queued operations on the device have been completed before proceeding.
    """
    # override this in your device implementation
  182. # **************** for HCQ Compatible Devices ****************
  183. def hcq_command(func):
  184. """
  185. Decorator for HWCommandQueue commands. Enables command indexing and stores metadata for command updates.
  186. For example:
  187. ```python
  188. @hcq_command
  189. def command_method(self, ...): ...
  190. ```
  191. """
  192. def __wrapper(self, *args, **kwargs):
  193. self.cmds_offset.append(len(self.q))
  194. func(self, *args, **kwargs)
  195. self.cmds_len.append(len(self.q) - self.cmds_offset[-1])
  196. self.cmds_meta.append(func.__name__)
  197. return self
  198. return __wrapper
class HWCommandQueue:
  """
  A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
  Both compute and copy queues should have the following commands implemented.
  """
  def __init__(self):
    # q: raw command words; cmds_offset/cmds_len/cmds_meta: per-command start offset, word
    # count and command name, maintained by the @hcq_command decorator so queued commands
    # can be located and patched later.
    self.q, self.binded_device, self.cmds_offset, self.cmds_len, self.cmds_meta = [], None, [], [], []
  def __len__(self): return len(self.cmds_offset)  # number of queued commands, not raw words
  def _patch(self, cmd_idx, offset, data):
    # overwrite part of an already-queued command in place with 32-bit words
    self.q[(st:=self.cmds_offset[cmd_idx]+offset):st+len(data)] = array.array('I', data)
  @hcq_command
  def signal(self, signal:Any, value:int):
    """
    Enqueues a signal command which sets the signal to the given value, ensuring all previous operations are completed.

    Args:
      signal: The signal to set
      value: The value to set the signal to
    """
    self._signal(signal, value)
  def _signal(self, signal:Any, value:int): raise NotImplementedError("backend should overload this function")
  @hcq_command
  def wait(self, signal:Any, value:int):
    """
    Enqueues a wait command which halts execution until the signal is greater than or equal to a specific value.

    Args:
      signal: The signal to wait on
      value: The value to wait for
    """
    self._wait(signal, value)
  def _wait(self, signal, value): raise NotImplementedError("backend should overload this function")
  @hcq_command
  def timestamp(self, signal:Any):
    """
    Enqueues a timestamp command which records the current time in a signal after all previously enqueued commands are completed.

    Args:
      signal: The signal to store the timestamp
    """
    self._timestamp(signal)
  def _timestamp(self, signal): raise NotImplementedError("backend should overload this function")
  def update_signal(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
    """
    Updates a previously queued signal command.

    Args:
      cmd_idx: Index of the signal command to update
      signal: New signal to set (if None, keeps the original)
      value: New value to set (if None, keeps the original)
    """
    if self.cmds_meta[cmd_idx] != "signal": raise RuntimeError("called update_signal not on a signal command")
    self._update_signal(cmd_idx, signal, value)
    return self
  def _update_signal(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
  def update_wait(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
    """
    Updates a previously queued wait command.

    Args:
      cmd_idx: Index of the wait command to update
      signal: New signal to wait on (if None, keeps the original)
      value: New value to wait for (if None, keeps the original)
    """
    if self.cmds_meta[cmd_idx] != "wait": raise RuntimeError("called update_wait not on a wait command")
    self._update_wait(cmd_idx, signal, value)
    return self
  def _update_wait(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
  def submit(self, device:HCQCompatCompiled):
    """
    Submits the command queue to a specific device for execution.

    Args:
      device: The device to submit the queue to
    """
    self._submit(device)
    return self
  def _submit(self, device:HCQCompatCompiled): raise NotImplementedError("backend should overload this function")
class HWComputeQueue(HWCommandQueue):
  """Command queue for compute work: barriers and kernel execution."""
  @hcq_command
  def memory_barrier(self):
    """
    Enqueues a memory barrier command to ensure memory coherence between agents.
    """
    self._memory_barrier()
  def _memory_barrier(self): pass  # default no-op: backends without explicit barriers need not override
  @hcq_command
  def exec(self, prg:HCQCompatProgram, kernargs:int, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int]):
    """
    Enqueues an execution command for a kernel program.

    Args:
      prg: The program to execute
      kernargs: The pointer to kernel arguments
      global_size: The global work size
      local_size: The local work size
    """
    self._exec(prg, kernargs, global_size, local_size)
  def _exec(self, prg, kernargs, global_size, local_size): raise NotImplementedError("backend should overload this function")
  def update_exec(self, cmd_idx:int, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int]):
    """
    Updates a previously queued execution command.

    Args:
      cmd_idx: Index of the execution command to update
      global_size: New global work size
      local_size: New local work size
    """
    if self.cmds_meta[cmd_idx] != "exec": raise RuntimeError("called update_exec not on an exec command")
    self._update_exec(cmd_idx, global_size, local_size)
    return self
  def _update_exec(self, cmd_idx, global_size, local_size): raise NotImplementedError("backend should overload this function")
  301. class HWCopyQueue(HWCommandQueue):
  302. @hcq_command
  303. def copy(self, dest:HCQCompatAllocRes, src:HCQCompatAllocRes, copy_size:int):
  304. """
  305. Enqueues a copy command to transfer data.
  306. Args:
  307. dest: The destination of the copy
  308. src: The source of the copy
  309. copy_size: The size of data to copy
  310. """
  311. self._copy(dest, src, copy_size)
  312. def _copy(self, dest:HCQCompatAllocRes, src:HCQCompatAllocRes, copy_size:int): raise NotImplementedError("backend should overload this function")
  313. def update_copy(self, cmd_idx:int, dest:Optional[HCQCompatAllocRes]=None, src:Optional[HCQCompatAllocRes]=None):
  314. """
  315. Updates a previously queued copy command.
  316. Args:
  317. cmd_idx: Index of the copy command to update
  318. dest: New destination of the copy (if None, keeps the original)
  319. src: New source of the copy (if None, keeps the original)
  320. """
  321. if self.cmds_meta[cmd_idx] != "copy": raise RuntimeError("called update_copy not on an copy command")
  322. self._update_copy(cmd_idx, dest, src)
  323. return self
  324. def _update_copy(self, cmd_idx, dest, src): raise NotImplementedError("backend should overload this function")
@contextlib.contextmanager
def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
  """
  Context manager that brackets the enclosed work with start/end timestamp signals.

  When `queue` is given, the timestamps are appended to it in place; otherwise a fresh
  `queue_type()` is built and submitted to `dev` for each timestamp. When PROFILE is set,
  the (start, end, desc, is_copy_queue) record is appended to `dev.sig_prof_records`.
  """
  st, en = (dev._alloc_signal(), dev._alloc_signal()) if enabled else (None, None)
  if enabled and queue is not None: queue.timestamp(st)
  elif enabled: queue_type().timestamp(st).submit(dev)
  try: yield (st, en)
  finally:
    if enabled and queue is not None: queue.timestamp(en)
    elif enabled: queue_type().timestamp(en).submit(dev)
    # NOTE(review): if enabled is true but PROFILE is unset, st/en are allocated above and
    # never recorded or freed here — looks like a signal leak; confirm callers always pass enabled=PROFILE
    if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))
  335. class HCQCompatProgram:
  336. def __init__(self, kernargs_alloc_size:int, kernargs_args_offset:int=0):
  337. self.kernargs_alloc_size, self.kernargs_args_offset = kernargs_alloc_size, kernargs_args_offset
  338. def fill_kernargs(self, kernargs_ptr:int, bufs:Tuple[Any, ...], vals:Tuple[int, ...]=()): raise NotImplementedError("need fill_kernargs")
  339. class HCQCompatCompiled(Compiled):
  340. """
  341. A base class for devices compatible with the HCQ (Hardware Command Queue) API.
  342. """
  343. def __init__(self, device:str, allocator:Allocator, renderer:Renderer, compiler:Compiler, runtime,
  344. comp_queue_t:Type[HWComputeQueue], copy_queue_t:Type[HWCopyQueue], timeline_signals:Tuple[Any, Any]):
  345. self.hw_compute_queue_t, self.hw_copy_queue_t = comp_queue_t, copy_queue_t
  346. self.timeline_value:int = 1
  347. self.timeline_signal, self._shadow_timeline_signal = timeline_signals
  348. self.sig_prof_records:List[Tuple[Any, Any, str, bool]] = []
  349. self.raw_prof_records:List[Tuple[int, int, str, bool]] = []
  350. if PROFILE: self._prof_setup()
  351. from tinygrad.runtime.graph.hcq import HCQGraph
  352. super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
  353. @classmethod
  354. def _read_signal(cls, signal:Any) -> int:
  355. """
  356. Read a value for a signal.
  357. """
  358. raise NotImplementedError("_read_signal needs to be implemented")
  359. @classmethod
  360. def _read_timestamp(cls, signal:Any) -> int:
  361. """
  362. Read a timestamp for a signal.
  363. """
  364. raise NotImplementedError("_read_timestamp needs to be implemented")
  365. @classmethod
  366. def _set_signal(cls, signal:Any, value:int):
  367. """
  368. Set a value for a signal.
  369. """
  370. raise NotImplementedError("_set_signal needs to be implemented")
  371. @classmethod
  372. def _alloc_signal(cls, value:int = 0, **kwargs) -> Any:
  373. """
  374. Allocate a new signal.
  375. """
  376. raise NotImplementedError("_alloc_signal needs to be implemented")
  377. @classmethod
  378. def _free_signal(cls, signal:Any):
  379. """
  380. Free a signal.
  381. """
  382. raise NotImplementedError("_free_signal needs to be implemented")
  383. @classmethod
  384. def _wait_signal(cls, signal:Any, value:int = 0, timeout:int = 10000):
  385. """
  386. Wait for a signal to reach a specific value. Signals
  387. """
  388. raise NotImplementedError("_wait_signal needs to be implemented")
  389. def _gpu2cpu_time(self, gpu_time:int, is_copy:bool) -> float:
  390. """
  391. Convert GPU time to CPU time. `is_copy` flag indicating if this is a copy queue.
  392. """
  393. raise NotImplementedError("_gpu2cpu_time needs to be implemented")
  394. def _prof_setup(self):
  395. if not hasattr(self, 'profile_logger'): atexit.register(self._prof_finalize)
  396. self.profile_logger = ProfileLogger()
  397. def _sync_queue(q_t):
  398. q_t().timestamp(self.timeline_signal).signal(self.timeline_signal, self.timeline_value).submit(self)
  399. self.timeline_value += 1
  400. cpu_start_time = time.perf_counter_ns() / 1e3
  401. self._wait_signal(self.timeline_signal, self.timeline_value - 1)
  402. return cpu_start_time, self._read_timestamp(self.timeline_signal)
  403. self.cpu_start_time, self.gpu_start_time = _sync_queue(self.hw_compute_queue_t)
  404. self.copy_cpu_start_time, self.copy_gpu_start_time = _sync_queue(self.hw_copy_queue_t)
  405. def _prof_process_events(self):
  406. self.raw_prof_records += [(self._read_timestamp(st), self._read_timestamp(en), name, is_cp) for st, en, name, is_cp in self.sig_prof_records]
  407. for st, en, _, _ in self.sig_prof_records: map(self._alloc_signal, [st, en])
  408. self.sig_prof_records = []
  409. def _prof_finalize(self):
  410. for st, en, name, is_cp in self.raw_prof_records:
  411. self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, ["COMPUTE", "DMA"][is_cp])]
  412. del self.profile_logger
  413. def _wrap_timeline_signal(self):
  414. self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
  415. self._set_signal(self.timeline_signal, 0)
  416. cast(HCQCompatAllocator, self.allocator).b_timeline = [0] * len(cast(HCQCompatAllocator, self.allocator).b)
# Structural protocol for HCQ-compatible allocation results: any object exposing the
# allocation's VA address and its size in bytes.
class HCQCompatAllocRes(Protocol): va_addr:int; size:int # noqa: E702
class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
  """
  A base allocator class compatible with the HCQ (Hardware Command Queue) API.

  This class implements basic copy operations following the HCQ API, utilizing both `HWComputeQueue` and `HWCopyQueue`.
  Host<->device transfers are staged through a ring of `batch_cnt` host-visible bounce
  buffers of `batch_size` bytes each, ordered by the device's timeline signal.
  """
  def __init__(self, device:HCQCompatCompiled, batch_size:int=(2 << 20), batch_cnt:int=32):
    self.device:Any = device
    self.b = [self._alloc(batch_size, BufferOptions(host=True)) for _ in range(batch_cnt)]
    # b_timeline[i] is the timeline value that must pass before staging buffer i may be reused
    self.b_timeline, self.b_next = [0] * len(self.b), 0
    super().__init__()
  def _alloc(self, size:int, options:BufferOptions) -> HCQCompatAllocRes: raise NotImplementedError("need hcq compat alloc")
  def copyin(self, dest:HCQCompatAllocRes, src:memoryview):
    # chunk src through the staging ring, overlapping the CPU memcpy with the device DMA
    with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
      for i in range(0, src.nbytes, self.b[0].size):
        self.b_next = (self.b_next + 1) % len(self.b)
        # block until the device has finished with this staging buffer before overwriting it
        self.device._wait_signal(self.device.timeline_signal, self.b_timeline[self.b_next])
        ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
        self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
          .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
          .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
        self.b_timeline[self.b_next] = self.device.timeline_value
        self.device.timeline_value += 1
  def copy_from_disk(self, dest:HCQCompatAllocRes, src, size):
    def _get_temp_buf():
      # Check if the next buffer is safe to be used (its signal has passed) and reserve it.
      if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device._read_signal(self.device.timeline_signal):
        # (1 << 64) marks the buffer as reserved until the real timeline value is recorded below
        self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
        return (self.b[self.b_next].va_addr, self.b_next)
      return None
    with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"DISK -> {self.device.dname}", enabled=PROFILE):
      for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
        self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
          .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
          .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
        self.b_timeline[batch_info[1]] = self.device.timeline_value
        self.device.timeline_value += 1
  def copyout(self, dest:memoryview, src:HCQCompatAllocRes):
    self.device.synchronize()
    # copy device -> staging buffer 0 -> dest, one chunk at a time (waits on each chunk, so no overlap)
    with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
      for i in range(0, dest.nbytes, self.b[0].size):
        self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
          .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
          .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
        self.device._wait_signal(self.device.timeline_signal, self.device.timeline_value)
        self.device.timeline_value += 1
        ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
  def transfer(self, dest:HCQCompatAllocRes, src:HCQCompatAllocRes, sz:int, src_dev, dest_dev):
    src_dev._gpu_map(dest)  # make dest visible to the source device before the DMA
    with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):
      # the copy runs on the source device's copy queue, gated on both devices' timelines
      src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
        .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
        .copy(dest.va_addr, src.va_addr, sz) \
        .signal(src_dev.timeline_signal, src_dev.timeline_value).submit(src_dev)
      src_dev.timeline_value += 1
    if src_dev != dest_dev:
      # cross-device: make the destination device's timeline observe the completed copy
      dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
        .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
        .signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
      dest_dev.timeline_value += 1
  def offset(self, buf, size:int, offset:int) -> HCQCompatAllocRes:
    # build a view allocation: same fields as buf, but va_addr shifted and size overridden
    return type(buf)(va_addr=buf.va_addr + offset, size=size, **{k:v for k,v in buf.__dict__.items() if k not in ['va_addr', 'size']},
                     **{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']}, _base=buf)