compile2.py

#!/usr/bin/env python3
import os, sys, io, pathlib, json, struct
import numpy as np
sys.path.insert(0, str(pathlib.Path(__file__).parents[1]))

if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1"
if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"

OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"

import onnx
from typing import Tuple, List, Optional, Dict, cast
from extra.onnx import get_run_onnx
from tinygrad import Tensor, Device, GlobalCounters, dtypes
from tinygrad.dtype import ImageDType
from tinygrad.device import Buffer
from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG, tqdm
from tinygrad.engine.realize import run_schedule, lower_schedule, ExecItem, CompiledRunner
from tinygrad.engine.schedule import ScheduleItem, create_schedule, memory_planner
from tinygrad.ops import MetaOps
from tinygrad.tensor import _to_np_dtype

Device.DEFAULT = "GPU"
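# "GPU" is tinygrad's OpenCL backend, which is what the generated thneed file targets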

def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem], Dict[str, Tensor]]:
  Tensor.no_grad = True
  Tensor.training = False

  # load the model
  onnx_model = onnx.load(io.BytesIO(onnx_data))
  run_onnx = get_run_onnx(onnx_model)
  input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}

  # run the model
  inputs = {k:Tensor.empty(*shp) for k,shp in input_shapes.items()}
  ret: Tensor = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).contiguous()
  schedule = create_schedule([ret.lazydata])

  # split out the schedule items that don't depend on the inputs
  input_lb = [x.lazydata.base.buffer for x in inputs.values()]
  depends = set(input_lb)
  for si in schedule:
    if any(b in depends for b in si.inputs):
      for out in si.outputs: depends.add(out)

  # run all kernels that don't depend on the inputs
  # NOTE: there are two extra kernels due to fusions that now happen since the weights aren't realized
  schedule, schedule_independent = partition(schedule, lambda si: any(out in depends for out in si.outputs))
  print(f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't")

  # confirm there are no non-sink metaops in the (input-dependent) schedule, except for the ones that load the input buffers
  assert all(si.ast.op is MetaOps.KERNEL or out in input_lb for si in schedule for out in si.outputs), "has non SINK ops, can't compile to Thneed"
  return schedule, schedule_independent, inputs
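# get_schedule returns: the input-dependent schedule (what gets compiled into the thneed file),
# the input-independent schedule (weight init, run once at compile time), and the placeholder input Tensors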

def test_vs_onnx(onnx_data, eis:Optional[List[ExecItem]], inputs:Dict[str, Tensor]):
  import onnx
  #import pyopencl as cl
  #from extra.thneed import Thneed
  import numpy as np
  onnx_model = onnx.load(io.BytesIO(onnx_data))
  input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}

  Tensor.manual_seed(1337)
  new_inputs = {k:Tensor.randn(*shp, requires_grad=False)*8 for k,shp in input_shapes.items()}
  new_np_inputs = {k:v.realize().numpy() for k,v in new_inputs.items()}

  if getenv("ORT"):
    # test with onnxruntime
    import onnxruntime as ort
    onnx_session = ort.InferenceSession(onnx_data)
    onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_np_inputs.items()})
    new_torch_out = onnx_output[0]
    print("got ort outputs")
  else:
    # test with torch
    from test.models.test_onnx import run_onnx_torch
    new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy()
    print("got torch outputs")

  # if you don't have a schedule
  if eis is None:
    run_onnx = get_run_onnx(onnx_model)
    new_tinygrad_out = next(iter(run_onnx(new_inputs).values())).cast(dtypes.float32).numpy()
    np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2)
    print("classic self-test passed!")
    return

  # set inputs
  for k,v in inputs.items(): v.lazydata.base.realized.copyin(new_np_inputs[k].data)

  # run code (all buffers have been allocated)
  GlobalCounters.reset()
  output = eis[-1].bufs[0]
  for ei in eis: ei.run()

  new_tinygrad_out = np.frombuffer(output.as_buffer(), dtype=_to_np_dtype(output.dtype))
  np.testing.assert_allclose(new_torch_out.reshape(new_tinygrad_out.shape), new_tinygrad_out, atol=1e-4, rtol=1e-2)
  print("semi-thneed self-test passed!")

if __name__ == "__main__":
  onnx_data = fetch(sys.argv[1] if len(sys.argv) > 1 else OPENPILOT_MODEL).read_bytes()
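  # CLI: sys.argv[1] is an optional path/URL to the ONNX model (defaults to OPENPILOT_MODEL),
  # sys.argv[2] is an optional thneed output path (defaults to /tmp/output.thneed, see output_fn below)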

  # quick test for ONNX issues
  #thneed_test_onnx(onnx_data, None)
  #exit(0)

  schedule, schedule_independent, inputs = get_schedule(onnx_data)
  schedule, schedule_input = partition(schedule, lambda x: x.ast.op is MetaOps.KERNEL)
  print(f"{len(schedule_input)} inputs")

  run_schedule(schedule_independent)
  run_schedule(schedule_input)
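  # if LATEBEAM is set, it enables tinygrad's BEAM kernel search while the remaining kernels are lowered below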
  with Context(DEBUG=max(DEBUG.value, 2), BEAM=getenv("LATEBEAM")):
    schedule = memory_planner(schedule)
    for si in schedule:
      for b in si.outputs:
        assert not b.is_allocated(), "output should not be allocated"
    image_count = sum(isinstance(out.dtype, ImageDType) for si in schedule for out in si.outputs)
    print(f"**** compiling real kernels {image_count}/{len(schedule)} images ****")
    eis = list(tqdm(lower_schedule(schedule), total=len(schedule)))
    print("kernel count:", len(eis))
    assert len(eis) <= getenv("ALLOWED_KERNEL_COUNT", 0) or getenv("ALLOWED_KERNEL_COUNT", 0) == 0, "too many kernels!"

  # new simple thneed
  def to_ref(b:Buffer): return struct.pack("Q", id(b)).decode("latin_1")
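  # buffers are referenced in the JSON by their Python id(), packed as 8 raw bytes and decoded as
  # latin-1 so the reference survives the json.dumps(..., ensure_ascii=False).encode('latin_1') round trip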

  seen_buffers = set()
  input_buffers = [x.lazydata.buffer for x in inputs.values()]
  jdat = {"binaries": [], "programs": {}, "kernels": [], "objects": []}
  jdat["inputs"] = {k:to_ref(v.lazydata.buffer) for k,v in inputs.items()}
  jdat["outputs"] = [to_ref(eis[-1].bufs[0])]

  weights = []
  for i,ei in enumerate(eis):
    #print("***", i)
    for b in ei.bufs:
      needs_load = b.is_allocated() and b not in input_buffers
      #print(b, needs_load)
      if b in seen_buffers: continue
      seen_buffers.add(b)
      if isinstance(b.dtype, ImageDType):
        base_dtype = dtypes.float16 if b.dtype.fmt == 'e' else dtypes.float32
        row_pitch = (b.dtype.shape[0]*4*base_dtype.itemsize + 63)//64 * 64
        size = row_pitch * b.dtype.shape[1]
        jdat['objects'].append({
          "id": to_ref(b), "needs_load": needs_load, "size": size, "arg_type": "image2d_t",
          "width": b.dtype.shape[0], "height": b.dtype.shape[1], "row_pitch": row_pitch, "float32": b.dtype.base == dtypes.float32,
        })
        if needs_load:
          t = Tensor.empty(b.dtype.shape, dtype=b.dtype)
          t.lazydata.buffer = b
          data = t.cast(dtypes.float32).pad(((0, row_pitch//(4*base_dtype.itemsize)-b.dtype.shape[0]), (0,0), (0,0))).contiguous().numpy()
          # NOTE: this cast must be done in numpy for platforms that don't support half
          if base_dtype == dtypes.float16: data = data.astype(np.float16)
          weights.append(data.tobytes())
          assert len(weights[-1]) == size, "wrong size buffer"
      else:
        jdat['objects'].append({
          "id": to_ref(b), "arg_type": b.dtype.name + "*", "needs_load": needs_load, "size": b.nbytes,
        })
        if needs_load:
          weights.append(b.as_buffer())
          assert len(weights[-1]) == b.nbytes, "wrong size buffer"
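
  # at this point jdat['objects'] describes every buffer the kernels touch; image buffers were
  # serialized with rows padded to the 64-byte-aligned row_pitch, and `weights` holds the raw bytes
  # for everything with needs_load (realized weights, but not the network inputs)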
  saved_binaries = set()
  binaries = []
  GlobalCounters.reset()
  with Context(DEBUG=max(DEBUG.value, 2)):
    for ei in eis:
      prg = cast(CompiledRunner, ei.prg)
      assert len(prg.p.vars) == 0
      if prg.p.function_name not in saved_binaries:
        jdat['binaries'].append({"name":prg.p.function_name, "length":len(prg.lib)})
        binaries.append(prg.lib)
        saved_binaries.add(prg.p.function_name)
      ei.run()
      jdat['kernels'].append({
        "name": prg.p.function_name,
        "work_dim": len(prg.p.global_size),
        "global_work_size": prg.p.global_size,
        "local_work_size": prg.p.local_size,
        "num_args": len(ei.bufs),
        "args": [to_ref(b) for b in ei.bufs],
        "arg_size": [8]*len(ei.bufs),
      })
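
  # thneed container layout (matching the writes below): a 4-byte JSON header length (struct "I"),
  # the JSON header itself, then every needs_load weight blob and every compiled kernel binary back
  # to back, in the order they appear in jdat['objects'] and jdat['binaries']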
  output_fn = sys.argv[2] if len(sys.argv) >= 3 else "/tmp/output.thneed"
  print(f"saving thneed to {output_fn} with {len(weights)} buffers and {len(binaries)} binaries")
  with open(output_fn, "wb") as f:
    j = json.dumps(jdat, ensure_ascii=False).encode('latin_1')
    f.write(struct.pack("I", len(j)))
    f.write(j)
    for w in weights: f.write(w)
    for b in binaries: f.write(b)
    print("saved", f.tell(), "bytes")

  FLOAT16 = getenv("FLOAT16", 0)
  if FLOAT16 == 0:
    try:
      test_vs_onnx(onnx_data, eis, inputs)
    except ModuleNotFoundError as e:
      print(f"TEST NOT HAPPENING {e}")