# onnx.py — run ONNX models with tinygrad
  1. from __future__ import annotations
  2. from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
  3. import importlib
  4. from functools import lru_cache
  5. import numpy as np
  6. from tinygrad import Tensor, dtypes, Device
  7. from tinygrad.tensor import _to_np_dtype
  8. from tinygrad.helpers import getenv, DEBUG, CI, OSX
  9. from tinygrad.dtype import ConstType
  10. from typing import List, Dict, Union
  11. from onnx import AttributeProto, ModelProto, TensorProto, TypeProto
# onnx >= 1.13 exposes tensor_dtype_to_np_dtype directly; older releases only
# ship the TENSOR_TYPE_TO_NP_TYPE mapping, so emulate the function on top of it.
try:
  from onnx.helper import tensor_dtype_to_np_dtype
except ImportError:
  # for onnx < 1.13
  from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
  tensor_dtype_to_np_dtype = lambda x: TENSOR_TYPE_TO_NP_TYPE[x]
# last observed miss count of _cached_to_python_const; to_python_const compares
# against it to print only the *new* cache misses when DEBUG >= 3
cache_misses = 0
  19. @lru_cache(None)
  20. def _cached_to_python_const(t:Tensor, tobytes): return t.data().tobytes() if tobytes else t.tolist()
  21. # Tensor -> python value cache for parameters
  22. def to_python_const(t, tobytes=False) -> Union[List[ConstType], List[bytes], Union[ConstType, bytes]]:
  23. if not isinstance(t, Tensor): return t
  24. global cache_misses
  25. ret = _cached_to_python_const(t, tobytes)
  26. if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3:
  27. print(f"Cache miss for {t}, {tobytes=}")
  28. cache_misses = info.misses
  29. return ret
  30. # copied from helpers.py
  31. def is_dtype_supported(dtype, device: str = Device.DEFAULT):
  32. if dtype == dtypes.bfloat16: return False
  33. if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
  34. if dtype == dtypes.half: return not (CI and device in {"GPU", "LLVM", "CUDA"})
  35. if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
  36. return True
# src: onnx/mapping.py
# not supported: STRING = 8 COMPLEX64 = 14, COMPLEX128 = 15
# NOTE: 17, 18, 19, 20 are float8, 10 is half
# onnx TensorProto.DataType enum value -> tinygrad dtype. The float8 variants
# (17-20) and half (10) are deliberately widened to float here.
DTYPE_MAP = {1:dtypes.float, 2:dtypes.uint8, 3:dtypes.int8, 4:dtypes.uint16, 5:dtypes.int16, 6:dtypes.int32, 7:dtypes.int64,
             9:dtypes.bool, 10:dtypes.float, 11:dtypes.double, 12:dtypes.uint32, 13:dtypes.uint64, 16:dtypes.bfloat16,
             17:dtypes.float, 18:dtypes.float, 19:dtypes.float, 20:dtypes.float}
# TODO: fix buffer_parse to use this and fix get_weight_and_biases to only use buffer_parse
# NOTE(review): imported via importlib by dotted name — presumably to resolve
# against the repo-root package layout; confirm before changing to a plain import
onnx_ops = importlib.import_module('extra.onnx_ops')
# when >= 0, run_onnx stops after node number ONNXLIMIT and returns that node's outputs
ONNXLIMIT = getenv("ONNXLIMIT", -1)
def get_run_onnx(onnx_model: ModelProto):
  """Parse `onnx_model` once (weights, attributes, opset version) and return a
  `run_onnx(inputs, debug)` callable that executes the graph with tinygrad
  Tensors and returns {output_name: Tensor}."""

  def type_parse(type_proto: TypeProto):
    """Extract a shape tuple from a TypeProto, unwrapping sequence_type /
    optional_type wrappers. Returns () when the shape can't be determined."""
    ret = []
    while True:
      attr = type_proto.WhichOneof('value')
      if attr == 'tensor_type':
        if "dim_value" not in type_proto.tensor_type.shape.dim.__dir__(): return () # variable type, unable to determine shape
        elif not ret:
          # plain tensor: its dims are the whole shape
          return tuple([x.dim_value for x in type_proto.tensor_type.shape.dim])
        else:
          # inside a sequence: each dim is recorded as its own 1-tuple
          ret.extend([(x.dim_value,) for x in type_proto.tensor_type.shape.dim])
          return tuple(ret)
      elif attr == 'sequence_type':
        type_proto = getattr(type_proto, attr).elem_type
        ret.append(1) # NOTE(review): a leading 1 marks one level of sequence nesting — confirm against callers
      elif attr == 'map_type': raise NotImplementedError(f"map_type is not implemented: {type_proto}")
      elif attr == 'opaque_type': raise NotImplementedError(f"opaque_type is not implemented: {type_proto}")
      elif attr == 'sparse_tensor_type': raise NotImplementedError(f"sparse_tensor_type is not implemented: {type_proto}")
      elif attr == 'optional_type': type_proto = getattr(type_proto, attr).elem_type
      else: raise Exception(f"unknown attr: {attr}, {type_proto}")

  def buffer_parse(inp: TensorProto) -> Tensor:
    """Convert a TensorProto (initializer or TENSOR attribute) into a Tensor."""
    # 8=STRING, 14=COMPLEX64, 15=COMPLEX128 have no tinygrad dtype
    if inp.data_type in (8,14,15): raise Exception(f"data type not supported {inp.name} {inp.dims} {inp.data_type}")
    # widen to float32 when the default device can't hold the declared dtype
    dtype = DTYPE_MAP[inp.data_type] if is_dtype_supported(DTYPE_MAP[inp.data_type]) else dtypes.float32
    # prefer the typed repeated fields when populated
    if dat := list(inp.float_data) or list(inp.int32_data) or list(inp.int64_data):
      return Tensor(dat, dtype=dtype, requires_grad=False).reshape(tuple(inp.dims))
    if len(inp.raw_data) > 0:
      # .copy() detaches from the read-only protobuf buffer before handing to Tensor
      return Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).astype(_to_np_dtype(dtype)).copy(),
                    requires_grad=False).reshape(tuple(inp.dims))
    # no payload at all: empty tensor
    return Tensor(None, requires_grad=False)

  def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int]:
    """Convert a node attribute into a python value (or Tensor for TENSOR)."""
    # TODO: this is not complete, see onnx/onnx_ml_pb2.pyi for a complete list
    if a.type == AttributeProto.FLOAT: return float(a.f)
    elif a.type == AttributeProto.INT: return int(a.i)
    elif a.type == AttributeProto.STRING: return a.s.decode("utf-8")
    elif a.type == AttributeProto.TENSOR: return buffer_parse(a.t) # TENSOR
    elif a.type == AttributeProto.FLOATS: return tuple(float(x) for x in a.floats)
    elif a.type == AttributeProto.INTS: return tuple(int(x) for x in a.ints)
    elif a.type == AttributeProto.STRINGS: return tuple(x.decode("utf-8") for x in a.strings)
    elif a.type == AttributeProto.GRAPH: raise Exception(f"graph not implemented: {a.g}\n likely an OP requiring control flow")
    else: raise Exception(f"can't parse {a.type} {a}")
  def attribute_to_dict(a: RepeatedCompositeFieldContainer[AttributeProto]): return {x.name:attribute_parse(x) for x in a}

  tensors: Dict[str, Tensor] = {}

  # get weights and biases
  for inp in onnx_model.graph.initializer:
    tensors[inp.name] = buffer_parse(inp)

  # preparse the attributes
  attribute_dict = {}
  domain = ""
  for num,n in enumerate(onnx_model.graph.node):
    attribute_dict[num] = attribute_to_dict(n.attribute)
    # NOTE: keeps only the last non-empty domain seen across all nodes
    if n.domain: domain = n.domain

  onnx_model_version = onnx_model.opset_import[0].version

  def run_onnx(inputs={}, debug=0):
    """Execute the graph. `inputs` maps input names to Tensors / lists /
    array-likes; `debug` (or env DEBUGONNX) raises verbosity 0-3."""
    debug = getenv("DEBUGONNX") or debug
    input_tensors: Dict[str,Tensor] = {}
    intermediate_tensors: Dict[str,Tensor] = {}
    output_tensor_names = [x.name for x in onnx_model.graph.output]

    # get inputs
    for inp in onnx_model.graph.input:
      # graph inputs that are initializers are weights, not user inputs
      if inp.name in tensors: continue
      shape = type_parse(inp.type)
      if inp.name in inputs:
        if isinstance(inputs[inp.name], Tensor):
          input_tensors[inp.name] = inputs[inp.name]
        elif isinstance(inputs[inp.name], list):
          input_tensors[inp.name] = [Tensor(i, requires_grad=False) for i in inputs[inp.name]]
        elif domain == "ai.onnx.preview.training": # not sure if in real use the domain is "ai.onnx.preview.training"
          input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=True) # TODO there isn't a good way to parse which inp requires_grad, some are manually turned off in optimizer ops
        else:
          input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=False)
        if shape: # if only input_tensor is not variable type
          input_shape = input_tensors[inp.name].shape if isinstance(input_tensors[inp.name], Tensor) else (1, *[i.shape for i in input_tensors[inp.name]])
          assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
      else:
        raise Exception(f"no data for {inp.name} with shape {shape}")

    def fetch_tensor(x: str):
      """Resolve a node input name: weight, then intermediate, then graph input.
      The empty name means an omitted optional input -> None."""
      if x in tensors: return tensors[x]
      if x in intermediate_tensors: return intermediate_tensors[x]
      if x != "": return input_tensors[x]
      return None

    for num,n in enumerate(onnx_model.graph.node):
      inp: List[Tensor] = []
      if debug >= 3: print("inputs:")
      for x in n.input:
        t = fetch_tensor(x)
        if debug >= 3: print(f"\t{x} - {t}")
        inp.append(t)
      opt: Dict = attribute_dict[num]
      if debug >= 1: print(f"{num}: op {n.op_type} shape {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}")

      # NOTE some ops live here because they require access to some local variables
      # have to use n.output for cases when num_outputs is absent
      if n.op_type in onnx_ops.tensor_methods:
        # ops that map 1:1 to a Tensor method (lowercased op name)
        ret = getattr(Tensor, n.op_type.lower())(*inp, **opt)
      elif n.op_type == "Split":
        axis = opt.get("axis", 0)
        split = None if len(inp) == 1 else to_python_const(inp[1])
        if split is None:
          # no explicit sizes: split as evenly as possible, earlier chunks get the remainder
          split = [inp[0].shape[axis] // len(n.output)] * len(n.output)
          for i in range(inp[0].shape[axis] % len(n.output)):
            split[i] += 1
        i, ret = 0, []
        arg = [None] * inp[0].ndim
        for s in split:
          # shrink out each consecutive [i, i+s) slab along axis
          arg[axis] = (i,i+s)
          ret.append(inp[0].shrink(arg=tuple(arg)))
          i = i+s
        ret = tuple(ret)
      # need to check onnx_model_version
      elif n.op_type == "Slice":
        if onnx_model_version < 10:
          # opset < 10: slice parameters come from attributes
          axes, ends, starts, steps = list(opt.get("axes", range(inp[0].ndim))), list(opt["ends"]), list(opt["starts"]), [1]*inp[0].ndim
        else:
          # opset >= 10: slice parameters come from inputs 1..4
          starts, ends = inp[1:3]
          axes = list(range(inp[0].ndim)) if len(inp) <= 3 else to_python_const(inp[3].cast(dtypes.int32))
          steps = inp[4].cast(dtypes.int32).tolist() if len(inp) > 4 else [1]*inp[0].ndim
          starts, ends = to_python_const(starts), to_python_const(ends)
        arg = [(0,x,1) for x in inp[0].shape]
        for i, axis in enumerate(axes):
          # normalize negative axes/indices and clamp to the dim size
          axis = int(axis) + inp[0].ndim if axis < 0 else int(axis)
          if starts[i] < 0: starts[i] += inp[0].shape[axis]
          if ends[i] < 0: ends[i] += inp[0].shape[axis]
          starts[i], ends[i] = max(0, min(starts[i], inp[0].shape[axis])), max(0, min(ends[i], inp[0].shape[axis]))
          # start past end with a non-negative step yields an empty/reversed range
          if starts[i] > ends[i] and steps[i] >= 0: steps[i] = -steps[i]
          arg[axis] = (starts[i], ends[i], steps[i])
        new_shape = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in arg)
        # any zero-size dim: shrink directly instead of fancy indexing
        if any(s==e for s,e in new_shape): ret = inp[0].shrink(new_shape)
        else: ret = inp[0][tuple([slice(s,e,st) for s,e,st in arg])]
      # need to call backward on intermediate_tensors
      elif n.op_type == "Gradient":
        assert len(opt["xs"]) == len(inp), f"len(opt['xs']):{len(opt['xs'])}, len(inp):{len(inp)} output and input has to match"
        y = opt["y"]
        intermediate_tensors[y].backward()
        ret = tuple([t.grad for t in inp])
      # onnx_ops.py
      elif hasattr(onnx_ops, n.op_type):
        fxn = getattr(onnx_ops, n.op_type)
        if isinstance(fxn, dict):
          # version-dispatched op: pick the newest implementation whose
          # opset key does not exceed the model's opset version
          for k in sorted(fxn.keys()):
            if k <= onnx_model_version:
              real_fxn = fxn[k]
        else:
          real_fxn = fxn
        ret = real_fxn(*inp, **opt)
      else:
        print("UNSUPPORTED", n.op_type, n.input, n.output)
        raise Exception(f"op_type {n.op_type} not supported")

      # normalize single outputs to a tuple; extra returned values are ignored
      if not isinstance(ret, tuple): ret = (ret, )
      assert len(n.output) <= len(ret), f"expected output size must be less than {len(ret)}, it's {n.output}"
      if debug >= 2: print([x.shape if isinstance(x, Tensor) else None for x in ret])
      if debug >= 2: print("outputs:")
      for i in range(len(n.output)):
        if debug >= 2: print(f"\t{n.output[i]} - {ret[i]}")
        intermediate_tensors[n.output[i]] = ret[i]
      if num == ONNXLIMIT:
        # debug early-exit: stop here and return this node's outputs instead
        output_tensor_names = n.output
        break
    return {outp:intermediate_tensors[outp] for outp in output_tensor_names}
  return run_onnx