driver.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import ctypes, struct, os, functools
  2. from typing import Union
  3. from dataclasses import dataclass
  4. from tinygrad.helpers import round_up, to_mv
  5. class VirtFileDesc:
  6. def __init__(self, fd): self.fd, self.off = fd, 0
  7. def read(self, fd, buf, sz): raise NotImplementedError()
  8. def ioctl(self, fd, req, argp): raise NotImplementedError()
  9. def mmap(self, st, sz, prot, flags, fd, off): raise NotImplementedError()
  10. def write(self, fd, buf, sz): raise NotImplementedError()
  11. def lseek(self, fd, off, whence): raise NotImplementedError()
  12. def fstat(self, fd, buf): raise NotImplementedError()
  13. def getdents(self, fd, buf, sz): return -1
  14. def close(self, fd): return 0
  15. class TextFileDesc(VirtFileDesc):
  16. def __init__(self, fd, text):
  17. super().__init__(fd)
  18. self.content = ctypes.create_string_buffer(text.encode())
  19. self.sz = len(self.content) - 1
  20. def ioctl(self, fd, req, argp): return 0
  21. def write(self, fd, buf, sz): return -1
  22. def read(self, fd, buf, sz):
  23. ctypes.memmove(buf, ctypes.addressof(self.content) + self.off, rdsz:=min(sz, self.sz - self.off))
  24. self.off += rdsz
  25. return rdsz
  26. def lseek(self, fd, off, whence):
  27. if whence == os.SEEK_SET: self.off = off
  28. elif whence == os.SEEK_CUR: self.off += off
  29. elif whence == os.SEEK_END: self.off = self.sz + off
  30. else: return -1
  31. return 0
  32. def fstat(self, fd, buf):
  33. ctypes.memmove(buf, VirtFile.build_fstat(st_size=self.sz), 88)
  34. return 0
  35. class DirFileDesc(VirtFileDesc):
  36. def __init__(self, fd, child_names):
  37. super().__init__(fd)
  38. child_names = ['.', '..'] + child_names
  39. tmp = b''
  40. for ino, name in enumerate(child_names):
  41. tmp += VirtFile.build_dirent(ino + 1, 0, name)
  42. self.content = ctypes.create_string_buffer(tmp)
  43. self.sz = len(self.content) - 1
  44. def ioctl(self, fd, req, argp): return 0
  45. def write(self, fd, buf, sz): return -1
  46. def read(self, fd, buf, sz): return -1
  47. def lseek(self, fd, off, whence):
  48. if whence == os.SEEK_SET: self.off = off
  49. elif whence == os.SEEK_CUR: self.off += off
  50. elif whence == os.SEEK_END: self.off = self.sz + off
  51. else: return -1
  52. return 0
  53. def getdents(self, fd, buf, sz):
  54. if self.sz == self.off: return 0
  55. if sz < self.sz: return -1
  56. ctypes.memmove(buf, ctypes.addressof(self.content) + self.off, self.sz)
  57. self.off = self.sz
  58. return self.sz
  59. def fstat(self, fd, buf):
  60. ctypes.memmove(buf, VirtFile.build_fstat(st_mode=0o40755), 96)
  61. return 0
  62. @dataclass(frozen=True)
  63. class VirtFile():
  64. path: str
  65. fdcls: Union[VirtFileDesc, functools.partial[VirtFileDesc]]
  66. @staticmethod
  67. def build_fstat(st_dev=0x20, st_ino=0x100000, st_mode=0o100777, st_nlink=1, st_uid=0, st_gid=0, st_rdev=0, st_size=0,
  68. st_blksize=4096, st_blocks=0, st_atime=0, st_mtime=0, st_ctime=0):
  69. assert (ssz:=struct.calcsize(fmt_string:='QQQIIIQQiQqqq')) == 96, f"{ssz} != 96"
  70. return struct.pack(fmt_string, st_dev, st_ino, st_nlink, st_mode, st_uid, st_gid,
  71. st_rdev, st_size, st_blksize, st_blocks, st_atime, st_mtime, st_ctime)
  72. @staticmethod
  73. def build_dirent(d_ino, d_off, d_name, d_type=None):
  74. # Start with packing inode number, offset, and record length
  75. d_reclen = round_up(19 + len(d_name) + 1, 8)
  76. packed_data = struct.pack('QQHc', d_ino, d_off, d_reclen, b'\x04')
  77. d_name_bytes = d_name.encode()
  78. return packed_data + d_name_bytes + b'\x00' + b'\x00' * (d_reclen - (19 + len(d_name) + 1))
  79. class VirtDriver:
  80. def __init__(self):
  81. self.tracked_files = []
  82. self.tracked_addresses = []
  83. def track_address(self, staddr, enaddr, rcb, wcb): self.tracked_addresses.append((staddr, enaddr, rcb, wcb))
  84. def open(self, name, flags, mode, fdcls): raise NotImplementedError()