ops_disk.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. from __future__ import annotations
  2. import os, sys, mmap, _posixshmem, io, ctypes, ctypes.util, platform, contextlib
  3. from typing import Optional, Generator, Tuple, Callable, List
  4. from tinygrad.helpers import OSX, round_up
  5. from tinygrad.device import Compiled, Allocator
  6. import tinygrad.runtime.autogen.io_uring as io_uring
  7. import tinygrad.runtime.autogen.libc as libc
  8. class DiskBuffer:
  9. def __init__(self, device:DiskDevice, size:int, offset=0):
  10. self.device, self.size, self.offset = device, size, offset
  11. def __repr__(self): return f"<DiskBuffer size={self.size} offset={self.offset}>"
  12. def _buf(self) -> memoryview:
  13. assert self.device.mem is not None, "DiskBuffer wasn't opened"
  14. return memoryview(self.device.mem)[self.offset:self.offset+self.size]
  15. MAP_LOCKED, MAP_POPULATE = 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000)
  16. class DiskAllocator(Allocator):
  17. def __init__(self, device:DiskDevice): self.device = device
  18. def _alloc(self, size:int, options):
  19. self.device._might_open(size)
  20. return DiskBuffer(self.device, size)
  21. def _free(self, opaque, options): self.device._might_close()
  22. def as_buffer(self, src:DiskBuffer): return src._buf()
  23. def copyin(self, dest:DiskBuffer, src:memoryview): dest._buf()[:] = src
  24. def copyout(self, dest:memoryview, src:DiskBuffer):
  25. if OSX and hasattr(self.device, 'fd'):
  26. # OSX doesn't seem great at mmap, this is faster
  27. with io.FileIO(self.device.fd, "a+b", closefd=False) as fo:
  28. fo.seek(src.offset)
  29. fo.readinto(dest)
  30. else:
  31. dest[:] = src._buf()
  32. def _copyout_sharded(self, src:DiskBuffer, size:int, _get_free_buf:Callable, seg_len:int) -> Generator[Tuple[int, int, int, int], None, None]:
  33. assert hasattr(DiskDevice, 'io_uring'), "function requires io uring support"
  34. fd_offset = src.offset - (minor_offset := src.offset % mmap.PAGESIZE)
  35. processed_reqs_cnt, copied_in, next_read_offset, total_copy_size = 0, 0, 0, round_up(size + minor_offset, mmap.PAGESIZE)
  36. reqs: List[Tuple[int, int, int, int]] = []
  37. while next_read_offset < total_copy_size or len(reqs) != processed_reqs_cnt:
  38. if next_read_offset < total_copy_size and (copy_batch := _get_free_buf()) is not None:
  39. # Prepare sqe
  40. sqe_index = (tail:=DiskDevice.io_uring.sq.ktail[0]) & DiskDevice.io_uring.sq.kring_mask[0]
  41. sqe = DiskDevice.io_uring.sq.sqes[sqe_index]
  42. sqe.opcode, sqe.fd, sqe.off = io_uring.IORING_OP_READ, self.device.fd, fd_offset + next_read_offset
  43. sqe.addr, sqe.len, sqe.user_data = copy_batch[0], min(seg_len, total_copy_size - next_read_offset), len(reqs)
  44. # Send sqe
  45. DiskDevice.io_uring.sq.array[sqe_index] = sqe_index
  46. DiskDevice.io_uring.sq.ktail[0] = tail + 1
  47. libc.syscall(io_uring.NR_io_uring_enter, DiskDevice.io_uring.ring_fd, 1, 1, io_uring.IORING_ENTER_GETEVENTS)
  48. reqs.append((copy_batch, copied_in, minor_offset, real_copy_size:=min(sqe.len - minor_offset, size - copied_in)))
  49. next_read_offset += sqe.len
  50. copied_in += real_copy_size
  51. minor_offset = 0
  52. if (head:=DiskDevice.io_uring.cq.khead[0]) != DiskDevice.io_uring.cq.ktail[0]:
  53. cqe = DiskDevice.io_uring.cq.cqes[head & DiskDevice.io_uring.cq.kring_mask[0]]
  54. assert cqe.res >= 0, f"read from disk failed, err: {cqe.res}"
  55. yield reqs[cqe.user_data]
  56. DiskDevice.io_uring.cq.khead[0] = head + 1 # advance
  57. processed_reqs_cnt += 1
  58. def offset(self, buf:DiskBuffer, size:int, offset:int): return DiskBuffer(buf.device, size, offset)
  59. class DiskDevice(Compiled):
  60. _tried_io_uring_init = False
  61. def __init__(self, device:str):
  62. if not DiskDevice._tried_io_uring_init: self._iouring_setup()
  63. self.size: Optional[int] = None
  64. self.count = 0
  65. super().__init__(device, DiskAllocator(self), None, None, None)
  66. def _might_open(self, size):
  67. self.count += 1
  68. assert self.size is None or size <= self.size, f"can't reopen Disk tensor with larger size, opened with {self.size}, tried to open with {size}"
  69. if self.size is not None: return
  70. filename = self.dname[len("disk:"):]
  71. self.size = size
  72. if filename.startswith("shm:"):
  73. fd = _posixshmem.shm_open("/"+filename[4:].lstrip("/"), os.O_RDWR, 0o600)
  74. self.mem = mmap.mmap(fd, self.size, mmap.MAP_SHARED | MAP_POPULATE | MAP_LOCKED)
  75. os.close(fd)
  76. else:
  77. try: self.fd = os.open(filename, os.O_RDWR|os.O_CREAT|(0 if OSX else os.O_DIRECT))
  78. except OSError: self.fd = os.open(filename, os.O_RDWR|os.O_CREAT)
  79. if os.fstat(self.fd).st_size < self.size: os.ftruncate(self.fd, self.size)
  80. self.mem = mmap.mmap(self.fd, self.size)
  81. if (hp := getattr(mmap, "MADV_HUGEPAGE", None)) is not None:
  82. with contextlib.suppress(OSError): self.mem.madvise(hp) # some systems have transparent_hugepage disabled
  83. def _might_close(self):
  84. self.count -= 1
  85. if self.count == 0:
  86. if hasattr(self, 'fd'): os.close(self.fd)
  87. self.size = None
  88. def _iouring_setup(self):
  89. DiskDevice._tried_io_uring_init = True
  90. if platform.system() != 'Linux' or hasattr(sys, "getandroidapilevel"): return
  91. fd = libc.syscall(io_uring.NR_io_uring_setup, 4096, ctypes.byref(p:=io_uring.struct_io_uring_params()))
  92. if fd < 0: return
  93. sq_ptr = libc.mmap(0, p.sq_off.array + p.sq_entries * 4, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_POPULATE, fd, 0)
  94. cq_ptr = libc.mmap(0, p.cq_off.cqes + p.cq_entries * ctypes.sizeof(io_uring.struct_io_uring_cqe),
  95. mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_POPULATE, fd, io_uring.IORING_OFF_CQ_RING)
  96. sqes = libc.mmap(0, p.sq_entries * ctypes.sizeof(io_uring.struct_io_uring_sqe),
  97. mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_POPULATE, fd, io_uring.IORING_OFF_SQES)
  98. def u32ptr(val): return ctypes.cast(val, ctypes.POINTER(ctypes.c_uint32))
  99. sqdesc = io_uring.struct_io_uring_sq(khead=u32ptr(sq_ptr+p.sq_off.head), ktail=u32ptr(sq_ptr+p.sq_off.tail), array=u32ptr(sq_ptr+p.sq_off.array),
  100. kring_mask=u32ptr(sq_ptr+p.sq_off.ring_mask), sqes=ctypes.cast(sqes, ctypes.POINTER(io_uring.struct_io_uring_sqe)))
  101. cqdesc = io_uring.struct_io_uring_cq(khead=u32ptr(cq_ptr+p.cq_off.head), ktail=u32ptr(cq_ptr+p.cq_off.tail),
  102. kring_mask=u32ptr(sq_ptr+p.cq_off.ring_mask), cqes=ctypes.cast(cq_ptr+p.cq_off.cqes, ctypes.POINTER(io_uring.struct_io_uring_cqe)))
  103. DiskDevice.io_uring = io_uring.struct_io_uring(ring_fd=fd, sq=sqdesc, cq=cqdesc) # type: ignore