1
0

ane.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. #!/usr/bin/env python3
  2. from pathlib import Path
  3. from ctypes import *
  4. import json
  5. import collections
  6. import numpy as np
  7. import faulthandler
  8. import struct
  9. faulthandler.enable()
  10. basedir = Path(__file__).resolve().parent
  11. libane = None
  12. aneregs = None
  13. def init_libane():
  14. global libane, aneregs
  15. libane = cdll.LoadLibrary((basedir / "libane.dylib").as_posix())
  16. libane.ANE_Compile.argtypes = [c_char_p, c_int]
  17. libane.ANE_Compile.restype = c_void_p
  18. libane.ANE_TensorCreate.restype = c_void_p
  19. libane.ANE_TensorData.argtypes = [c_void_p]
  20. libane.ANE_TensorData.restype = POINTER(c_uint16)
  21. libane.ANE_Run.argtypes = [c_void_p]*4
  22. libane.ANE_Run.restype = c_int
  23. #libane.ANE_RegDebug.restype = c_char_p
  24. with open(basedir / "aneregs.json") as f:
  25. aneregs = json.load(f)
  26. ANE_Struct = [
  27. # aneTD.Header
  28. ("u32", 0x1C, "NextCommandOffset"),
  29. # KernelDMASrc @ section @ 0x2C len 0xF4
  30. # reloc 0x2c-0x34?? = weights
  31. # u32[16] 0x34-0x74 = 0x80 | 1 if used
  32. # u32[16] 0x74-0xB4 = <channel data offset>
  33. # u32[16] 0xB4-0xF4 = <channel data length>
  34. # Common @ section @ 0x128 len 0x3C (conv)
  35. ("u16", 0x128, "InputWidth"),
  36. ("u16", 0x12A, "InputHeight"),
  37. ("u16", 0x12C, "InputDepth"),
  38. ("u32", 0x130, "InputOutputType"), # (OutputType * 0x10) | InputType
  39. # UInt8 = 0, Int8 = 1, Float16 = 2
  40. ("u32", 0x134, "InputChannels"),
  41. ("u32", 0x138, "OutputChannels"),
  42. ("u16", 0x13C, "OutputWidth"),
  43. ("u16", 0x13E, "OutputHeight"),
  44. ("u16", 0x140, "OutputDepth"),
  45. ("u16", 0x144, "KernelSize"), # 0xa000 | (KernelHeight * 0x20) | KernelWidth
  46. ("u16", 0x146, "Padding"), # 0x5000 | (PadTop * 0x40) | (PadLeft * 2)
  47. ("u16", 0x14C, "BatchSize"),
  48. # TileDMASrc @ section @ 0x16C len 0x6C (input)
  49. # reloc 0x16c-0x174 = image
  50. ("u32", 0x178, "InputRowStride"),
  51. ("u32", 0x17C, "InputPlaneStride"),
  52. ("u32", 0x180, "InputDepthStride"),
  53. ("u32", 0x184, "InputBatchStride"),
  54. ("u8", 0x1A7, "InputInterleave"),
  55. # L2 @ section @ 0x1E0 len 0x44
  56. # [0x1ec, 0x1f0, 0x1f4, 0x1f8, 0x214] = number of engines
  57. # [0x1f0, 0x1f4, 0x1f8, 0x214] = engines for inconv?
  58. # [0x21c, 0x220, 0x224] = engines for outconv?
  59. # NE @ section @ 0x22c len 0xC (scaling)
  60. ("u16", 0x230, "BiasScalar"),
  61. ("u16", 0x232, "ScaleScalar"),
  62. # section @ 0x240 len 0x10
  63. ("u16", 0x246, "NeuronType"), # 0x10 = copy, 0x11 = ReLU, 0x12 = custom
  64. ("u32", 0x250, "PostScale"),
  65. # TileDMADst @ section @ 0x258 len 0x18
  66. # HandleTileDmaDstConfig
  67. # 0x258 -- *(uint *)(this + 0x334) = *(uint *)(this + 0x334) & 0xfffffc3f | 0xc0;
  68. # (GetCacheHintRegisterValue & 0xf) << 6;
  69. ("u32", 0x25C, "OutputOffset"), # offset into output buffer to write at?
  70. # 0x260 -- *(uint *)(this + 0x33c) = *(uint *)(this + 0x33c) & 0x3f | (int)uVar10 << 6;
  71. ("u32", 0x260, "OutputRowStride"),
  72. ("u32", 0x264, "OutputPlaneStride"),
  73. ("u32", 0x268, "OutputDepthStride"),
  74. ("u32", 0x26C, "OutputBatchStride"),
  75. # 0x270 -- *(uint *)(this + 0x34c) = *(uint *)(this + 0x34c) & 0xf0ffffff | 0x1000000;
  76. # uVar6 = *(uint *)(this + 0x34c) & 0xffffcfcc | 0x2031;
  77. # (ZinTensorDescriptorDmaInterleave & 0xf) << 0x18;
  78. ("u8", 0x273, "OutputInterleave"), # i also have this at 0x211?
  79. ]
  80. ANE_Struct_Dict = {}
  81. for typ, num, nam in ANE_Struct:
  82. styp = {"u32": "I", "u16": "H", "u8": "B"}[typ]
  83. ANE_Struct_Dict[nam] = (styp, num)
  84. class ANETensor:
  85. def __init__(self, *shape):
  86. self.shape = shape
  87. self.dtype = np.float16
  88. self.sz = int(np.prod(shape))
  89. assert(self.sz <= 0x4000)
  90. self.tt = libane.ANE_TensorCreate(self.sz, 1)
  91. assert(self.tt is not None)
  92. def data(self):
  93. data = libane.ANE_TensorData(self.tt)
  94. assert(data is not None)
  95. #print(hex(addressof(data.contents)))
  96. buf = np.ctypeslib.as_array(data, shape=(self.sz,))
  97. ret = np.frombuffer(buf, dtype=self.dtype)
  98. #print(ret.data)
  99. return ret
  100. class ANE:
  101. def __init__(self):
  102. init_libane()
  103. libane.ANE_Open()
  104. def compile(self, dat):
  105. ret = libane.ANE_Compile(create_string_buffer(dat), len(dat))
  106. assert(ret is not None)
  107. return ret
  108. def run(self, prog, tin, tout, tweights=None):
  109. libane.ANE_Run(prog, tin.tt, tout.tt, tweights.tt if tweights is not None else 0)
  110. def tensor(self, shape):
  111. return ANETensor(shape)
  112. def unpack(self, dat):
  113. dat = struct.unpack("Q"*(len(dat)//8), dat)
  114. ret = {}
  115. for k,v in aneregs:
  116. by,bi,sz = v
  117. bi += (by%8)*8
  118. by //= 8
  119. rv = (dat[by] >> bi) & ((1 << sz)-1)
  120. ret[k] = rv
  121. return ret
  122. def pack(self, pk, dat):
  123. dat = list(struct.unpack("Q"*(len(dat)//8), dat))
  124. for k,v in aneregs:
  125. by,bi,sz = v
  126. bi += (by%8)*8
  127. by //= 8
  128. dat[by] &= ~(((1 << sz)-1) << bi)
  129. dat[by] |= pk[k] << bi
  130. dat = struct.pack("Q"*len(dat), *dat)
  131. return dat
  132. def debug(self, dat, mems=0):
  133. add = [0x30, 0x1d4, 0x220, 0x29c, 0x2f0, 0x30c, 0x32c]
  134. lens = [244, 60, 108, 68, 12, 16, 24]
  135. ptr = 0x2b
  136. ddat = dat[0:0x28]
  137. for a, pm in zip(add, lens):
  138. #assert pm == dat[ptr]
  139. ddat += b"\x00" * (a-len(ddat))
  140. ddat += dat[ptr+1:ptr+1+pm+4]
  141. ptr += pm+8
  142. ddat += b"\x00" * 0x100
  143. ret = collections.OrderedDict()
  144. for ln in libane.ANE_RegDebug(0, create_string_buffer(ddat), mems).decode('utf-8').strip().split("\n"):
  145. lnn = ln.split(" = ")
  146. if len(lnn) == 2:
  147. ret[lnn[0]] = int(lnn[1])
  148. return ret
  149. def filln(self, dat, nvdict, base=0x4000):
  150. for n,v in nvdict.items():
  151. styp, num = ANE_Struct_Dict[n]
  152. dat = self.fill(dat, [num], styp, v)
  153. return dat
  154. def fill(self, dat, addrs, type, val, base=0x4000):
  155. x = struct.pack(type, val)
  156. for a in addrs:
  157. dat[base+a:base+a+len(x)] = x
  158. return dat
  159. if __name__ == "__main__":
  160. ane = ANE()
  161. tin = ANETensor(16)
  162. tout = ANETensor(16)
  163. tind = tin.data()
  164. toutd = tout.data()
  165. tind[0:4] = [-1,1,-2,2]
  166. print("** before **")
  167. print(tind)
  168. print(toutd)
  169. dat = open("../ops/relu.hwx", "rb").read()
  170. md = dat[0x4000:0x4300]
  171. dd = ane.unpack(md)
  172. mdf = ane.pack(dd, md)
  173. assert(md == mdf)
  174. comp = ane.compile(dat)
  175. ret = ane.run(comp, tin, tout)
  176. print("** after **")
  177. print(tind)
  178. print(toutd)