onnx_ops.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723
  1. import functools, io, math
  2. from typing import Union, Tuple, Optional, List, Any
  3. from tinygrad.tensor import Tensor, _broadcast_shape
  4. from tinygrad.dtype import ImageDType, dtypes
  5. from tinygrad.helpers import prod, flatten
  6. from extra.onnx import DTYPE_MAP, to_python_const
  7. import numpy as np
  8. tensor_methods = {"Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Relu", "Sigmoid", "MatMul",
  9. "Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign", "Asinh", "Acosh", "Atanh",
  10. "Elu", "Celu", "Xor", "Round"}
  11. # **************** Free Ops ****************
  12. def Identity(x: Tensor): return x
  13. # TODO: fix buffer_parse
  14. def Add(x: Tensor, other: Tensor, broadcast=None, axis=None): return x + other if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + other).cast(x.dtype)
  15. def Sub(x: Union[Tensor, Any], other: Tensor): return x - other # some test has input as int
  16. def Less(x:Tensor,y:Tensor): return x < y
  17. def LessOrEqual(x:Tensor,y:Tensor): return x <= y
  18. def Greater(x:Tensor,y:Tensor): return x > y
  19. def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y
  20. def Equal(x:Tensor,y:Tensor): return x == y
  21. def Max(*data_0): return functools.reduce(Tensor.maximum, data_0)
  22. def Min(*data_0): return functools.reduce(Tensor.minimum, data_0)
  23. def Sum(*data_0): return functools.reduce(Tensor.add, data_0)
  24. def Mean(*data_0): return Sum(*data_0) / len(data_0)
  25. # NOTE: does not support saturate
  26. def Cast(x: Tensor, to: int, saturate=1): return x.cast(DTYPE_MAP[to])
  27. def CastLike(x: Tensor, target_type: Tensor, saturate=1): return x.cast(target_type.dtype)
  28. # **************** Simple Ops ****************
  29. # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_div.py
  30. def Div(x: Tensor, other: Tensor): return (x/other).cast(x.dtype)
  31. def Constant(value:Optional[Tensor]=None, value_float=None, value_floats=None, value_int=None, value_ints=None, value_string=None, value_strings=None):
  32. if value is not None: return value
  33. if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
  34. if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
  35. if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
  36. if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
  37. if value_string is not None or value_strings is not None: raise NotImplementedError('value_string or value_strings not implemented for Constant op')
  38. def HardSigmoid(x: Tensor, alpha=0.2, beta=0.5): return (alpha*x + beta).clip(0, 1)
  39. def Gelu(x:Tensor, approximate=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + Erf(x/math.sqrt(2)))
  40. def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
  41. def PRelu(X:Tensor, slope:Tensor):
  42. slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
  43. return (X > 0).where(X, X * slope)
  44. def LeakyRelu(X: Tensor, alpha=0.01): return X.leakyrelu(alpha)
  45. def ThresholdedRelu(X: Tensor, alpha=1.0): return (X > alpha).where(X, 0)
  46. def Softmax_1(x: Tensor, axis=1): return x.softmax(axis)
  47. def Softmax_13(x: Tensor, axis=-1): return x.softmax(axis)
  48. Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed
  49. def LogSoftmax(x: Tensor, axis=-1): return x.log_softmax(axis)
  50. def Clip(x: Tensor, min=None, max=None): return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype)
  51. # NOTE ReduceProd would require a new llop
  52. def _axes(axes, noop_with_empty_axes):
  53. if axes is not None and not (isinstance(axes, Tensor) and axes.shape == (0,)): return to_python_const(axes)
  54. return [] if noop_with_empty_axes else None
  55. def ReduceMax(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
  56. def ReduceMin(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
  57. def ReduceSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
  58. def ReduceMean(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
  59. def ReduceSumSquare(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return ReduceSum(data.square(), axes, keepdims, noop_with_empty_axes)
  60. def ReduceL1(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes)
  61. def ReduceL2(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt()
  62. def ReduceLogSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
  63. def ReduceLogSumExp(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log()
  64. def GlobalAveragePool(X: Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
  65. def GlobalMaxPool(X: Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)
  66. def OptionalHasElement(x: Optional[Tensor]=None): return Tensor(x is not None and x.numel() > 0)
  67. def OptionalGetElement(x: Optional[Tensor]=None): return x if x is not None else Tensor([])
  68. def Tile(x: Tensor, repeats): return x.repeat(to_python_const(repeats))
  69. def Range(start: Tensor, limit, delta): return Tensor.arange(start=to_python_const(start), stop=to_python_const(limit), step=to_python_const(delta))
  70. def Shape(data: Tensor, end=None, start=0): return Tensor(data.shape[start:end], dtype=dtypes.int64)
  71. def Size(data: Tensor): return prod(data if isinstance(data, list) else data.shape)
  72. def Flatten(x: Tensor, axis=1): return x.reshape(prod(x.shape[0:axis]), -1)
  73. def Reshape(data: Tensor, shape: Tensor, allowzero=0):
  74. return data.reshape([int(x) if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(to_python_const(shape))])
  75. def Expand(x: Tensor, shape:Tensor): return x.expand(_broadcast_shape(x.shape, tuple(to_python_const(shape))))
  76. def Shrink(x: Tensor, bias=0.0, lambd=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)
  77. def And(x:Tensor, y:Tensor): return (x==y).where(x, False)
  78. def Or(x:Tensor, y:Tensor): return (x==y).where(x, True)
  79. def Not(x:Tensor): return x.logical_not()
def Asin(x):
  # asin(x) = atan(x / sqrt(1 - x^2)), evaluated with the polynomial Atan below
  return Atan(x / (1 - x * x).sqrt())
def Acos(x: Tensor):
  """Polynomial approximation of arccos: poly(|x|) * sqrt(1 - |x|), reflected for negative x."""
  negate = (x < 0)
  x = x.abs()
  ret = ((((-0.0187293 * x) + 0.0742610)*x - 0.2121144) * x + 1.5707288) * (1.0 - x).sqrt()
  # for x < 0 this turns ret into -ret, and the pi term below completes acos(-a) = pi - acos(a)
  ret = ret - 2 * negate * ret
  return negate * math.pi + ret
def Atan(y: Tensor):
  """Polynomial approximation of arctan, odd in y (the sign of y is reapplied at the end)."""
  t1 = y.abs()
  # fold |y| > 1 into [0, 1] using atan(a) = pi/2 - atan(1/a)
  t3 = (1 > t1).where(t1, t1.reciprocal())
  t4 = t3 * t3
  # polynomial in t3^2; multiplied by t3 below it approximates atan on [0, 1]
  t0 = ((((-0.013480470 * t4 + 0.057477314) * t4 - 0.121239071) * t4 + 0.195635925) * t4 - 0.332994597) * t4 + 0.999995630
  t3 = t0 * t3
  # undo the fold for |y| > 1 (1.570796327 ~= pi/2)
  t3 = (t1 > 1).where(1.570796327 - t3, t3)
  return y.sign() * t3
  95. def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
  96. k = to_python_const(k) if isinstance(k, Tensor) else 0 # onnx passes k as a tensor int64 with one element, default is 0
  97. return x.triu(k) if upper else x.tril(k)
  98. def Squeeze(data: Tensor, axes):
  99. if isinstance(axes, Tensor): axes = to_python_const(axes)
  100. axes = [data._resolve_dim(x) for x in axes]
  101. return data.reshape([s for i,s in enumerate(data.shape) if i not in axes])
  102. def Unsqueeze(data: Tensor, axes):
  103. axes = sorted([x + data.ndim if x < 0 else x for x in to_python_const(axes)])
  104. new_shape = list(data.shape)
  105. for axis in axes: new_shape.insert(axis, 1)
  106. return data.reshape(new_shape)
  107. def Binarizer(x, threshold=0.0): return (x > threshold).float()
  108. def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
  109. if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
  110. return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
  111. def ArgMin(x, axis=0, keepdims=1, select_last_index=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
  112. def Concat(*xs: List[Tensor], axis): return Tensor.cat(*xs, dim=axis)
  113. def Transpose(x: Tensor, perm=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm)
  114. def ConstantOfShape(x, value:Tensor=None):
  115. if value is None: value = 0.0
  116. shape = to_python_const(x)
  117. return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0]!=0 else 1)
  118. # **************** Complex Ops ****************
  119. def Gemm(A: Tensor, B: Tensor, C: Tensor=None, alpha=1.0, beta=1.0, transA=0, transB=0, broadcast=0):
  120. ret = alpha * (A.transpose(transA) @ B.transpose(transB))
  121. if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
  122. return ret
  123. def Einsum(*Inputs: List[Tensor], equation): return Tensor.einsum(equation, Inputs)
  124. def CumSum(X:Tensor, axis:Tensor, exclusive=0, reverse=0):
  125. axis = to_python_const(axis)
  126. if axis < 0: axis += X.ndim
  127. if reverse: X = X.flip(axis)
  128. if exclusive:
  129. pad_arg, shrink_arg = [None] * X.ndim, [None] * X.ndim
  130. pad_arg[axis] = (1, 0)
  131. shrink_arg[axis] = (0, X.shape[axis])
  132. X = X.pad(tuple(pad_arg)).shrink(tuple(shrink_arg))
  133. if reverse: return X.cumsum(axis).flip(axis)
  134. return X.cumsum(axis)
  135. # TODO: this is copied from tinygrad/nn/__init__.py
  136. # spatial is from opset 7 and has since been removed
  137. def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05, momentum=0.9, training_mode=0, spatial=1, is_test=0):
  138. if training_mode:
  139. x_detached = X.detach()
  140. current_mean = x_detached.mean(axis=(0,2,3))
  141. y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
  142. current_var = (y*y).mean(axis=(0,2,3))
  143. current_invstd = current_var.add(epsilon).rsqrt()
  144. running_mean = input_mean * momentum + current_mean * (1 - momentum)
  145. running_var = input_var * momentum + current_var * (1 - momentum)
  146. return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
  147. invstd = (input_var + epsilon).rsqrt()
  148. return X.batchnorm(scale, B, input_mean, invstd)
  149. def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05):
  150. axis = tuple(range(2, x.ndim))
  151. mean = x.mean(axis=axis, keepdim=True)
  152. invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
  153. return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
  154. def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_type=1):
  155. assert stash_type == 1, "only float32 is supported"
  156. axis = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
  157. mean = x.mean(axis=axis, keepdim=True)
  158. return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
  159. def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05):
  160. return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
  161. # onnx: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
  162. # numpy.pad: ((x1_begin, x1_end), (x2_begin, x2_end), ...)
  163. def _format_padding(onnx_pads, ndims=None, axes=None):
  164. if ndims and len(onnx_pads)//2 != ndims: onnx_pads = onnx_pads * ndims # for OnnxBackendPyTorchConvertedModelTest the len(onnx_pads) == 2
  165. if ndims is None: ndims = len(onnx_pads) // 2
  166. if axes is None: axes = list(range(ndims))
  167. num_axes = len(axes)
  168. np_pads = [(0,0)] * ndims
  169. for i in range(num_axes):
  170. np_pads[axes[i]] = (onnx_pads[i], onnx_pads[i + num_axes])
  171. return np_pads
def _padded(X: Tensor, pads=None, auto_pad="NOTSET", axes=None, constant_value=0., strides=None, kernel_shape=None, dilations=None, ceil_mode=0):
  """Pad X with constant_value according to ONNX pad / pooling attributes.

  pads: ONNX-layout pads ([begins..., ends...]) for the dims in `axes` (all dims if axes is None).
  auto_pad != "NOTSET": pads are recomputed by _auto_pad.
  ceil_mode (with auto_pad "NOTSET"): pads are recomputed so pooling yields the ceil'd output size.
    NOTE(review): any explicitly passed `pads` are overwritten in this branch — confirm that is intended.
  Returns X unchanged when no pads result."""
  if auto_pad != "NOTSET": pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
  elif ceil_mode:
    # normalize strides/dilations to per-spatial-dim lists.
    # NOTE(review): an int dilation other than 1, or strides=None, would break the zips below.
    if strides is not None: strides = [strides]*len(kernel_shape) if isinstance(strides, int) else strides if strides else [1]*len(kernel_shape)
    if dilations is not None: dilations = [1]*len(kernel_shape) if dilations == 1 else dilations
    # the `if ceil_mode` test is always true inside this branch, so the floor alternative is never taken
    out_spatial_shape = [math.ceil((sh - dil * (ker-1)-1)/st + 1) if ceil_mode else math.floor((sh - dil * (ker-1)-1)/st + 1) for sh, st, ker, dil in zip(X.shape[-len(kernel_shape):], strides, kernel_shape, dilations)]
    # total pad per spatial dim needed to produce out_spatial_shape, split [begin, end] with any odd unit at the end
    pad_shape = [(osh-1)*st+((ks-1)*dil+1)-ish for osh, st, ks, dil, ish in zip(out_spatial_shape, strides, kernel_shape, dilations, X.shape[-len(kernel_shape):])]
    pad_shape = [[sh//2, sh-sh//2] for sh in pad_shape]
    # ceil_mode case follows NOTE in https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
    # so if any kernels start in right padded region, we decrease right pads to omit that kernel. Only omitting 1 kernel now.
    pad_shape = [[start,end-rpad] if (rpad := ks + st%(st-(((start+xs)%st)))) <= end else [start,end]
                 for (start,end), ks, st, xs in zip(pad_shape, kernel_shape, strides, X.shape[-len(kernel_shape):])]
    pad_shape = flatten(pad_shape)
    # back to ONNX layout: [begins..., ends...]
    pads = pad_shape[::2] + pad_shape[1::2]
  if pads is None: return X
  pads = _format_padding(pads, ndims=len(X.shape), axes=axes)
  return X.pad(tuple(pads), value=constant_value)
  189. def _auto_pad(X: Tensor, auto_pad, strides, kernel_shape, dilations):
  190. strides = [strides]*len(kernel_shape) if isinstance(strides, int) else strides if strides else [1]*len(kernel_shape)
  191. dilations = [1]*len(kernel_shape) if dilations == 1 else dilations
  192. if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
  193. pad_shape = [(math.ceil(sh/st)-1)*st+((ks-1)*di+1)-sh for sh, st, ks, di in zip(X.shape[-len(kernel_shape):], strides, kernel_shape, dilations)]
  194. pad_shape = flatten([[sh//2, sh-sh//2] for sh in pad_shape])
  195. return pad_shape[::2] + pad_shape[1::2] if auto_pad == "SAME_UPPER" else pad_shape[1::2] + pad_shape[::2]
  196. raise NotImplementedError(f"auto_pad={auto_pad} not implemented")
def Pad(x: Tensor, pads: Union[Tensor, Tuple[int, ...]], constant_value: Tensor=None, axes: Tensor=None, mode="constant", value: float=0.):
  """ONNX Pad with modes "constant", "reflect", "edge" and "wrap".

  pads: ONNX layout [begins..., ends...] for the dims in `axes` (all dims if axes is None); entries are ceil'd.
  constant_value (input) takes precedence over the legacy `value` attribute.
  NOTE(review): an unrecognized mode falls through every branch and returns None."""
  constant_value = value if constant_value is None else float(to_python_const(constant_value))
  seq_pads = list(pads) if isinstance(pads, tuple) else to_python_const(pads)
  seq_pads = [math.ceil(i) for i in seq_pads]
  seq_axes = to_python_const(axes) if axes is not None else None
  base_shape = x.shape
  # per-dim (begin, end) pairs, one for every dim of x
  pads = _format_padding(seq_pads, ndims=len(x.shape), axes=seq_axes)
  if mode == "wrap":
    # tile x enough times to cover both pads, then crop so exactly `begin` elements precede and `end`
    # elements follow one central copy of x
    repeat_args = [math.ceil(dim[0]/sh) + math.ceil(dim[1]/sh) + 1 for dim, sh in zip(pads, base_shape)]
    new_shape = [s*r for s,r in zip(base_shape, repeat_args)]
    shrink_args = [(sh-dim[0]%sh if dim[0]%sh != 0 else 0, nsh-(sh-dim[1]%sh if dim[1]%sh != 0 else 0)) for dim, sh, nsh in zip(pads, base_shape, new_shape)]
    return x.repeat(tuple(repeat_args)).shrink(tuple(shrink_args))
  if mode == "reflect":
    for i,s in enumerate(x.shape):
      if pads[i] != (0,0):
        # mirror without repeating the border: left = x[p], ..., x[1]; right = x[s-2], ..., x[s-1-p] along dim i
        xL = x.flip(i).shrink(tuple((s-pads[i][0]-1, s_-1) if i_ == i else None for i_,s_ in enumerate(x.shape)))
        xR = x.flip(i).shrink(tuple((1, pads[i][1]+1) if i_ == i else None for i_ in range(x.ndim)))
        x = xL.cat(x, xR, dim=i)
    return x
  if mode == "edge":
    for i,s in enumerate(x.shape):
      if pads[i] != (0,0):
        # repeat the first / last slice along dim i by the begin / end pad count
        xL = x.shrink(tuple((0,1) if i_ == i else None for i_ in range(x.ndim))).expand([pads[i][0] if i_ == i else None for i_ in range(x.ndim)])
        xR = x.shrink(tuple((s_-1, s_) if i_ == i else None for i_,s_ in enumerate(x.shape))).expand([pads[i][1] if i_ == i else None for i_ in range(x.ndim)])
        x = xL.cat(x, xR, dim=i)
    return x
  if mode == "constant":
    return _padded(x, seq_pads, axes=seq_axes, constant_value=constant_value)
  225. def AveragePool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, count_include_pad=0, dilations=1, pads=None, strides=1):
  226. pixel_axes = tuple(range(2, X.ndim))
  227. ret = _padded(X, pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations, ceil_mode=ceil_mode)
  228. ret = ret.avg_pool2d(kernel_shape, stride=strides, dilation=dilations)
  229. if count_include_pad: return ret
  230. div = _padded(Tensor.ones(X.shape), pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations, ceil_mode=ceil_mode).avg_pool2d(kernel_shape, stride=strides, dilation=dilations)
  231. return ret / div
def MaxPool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, dilations=1, pads=None, storage_order=0, strides=1):
  """ONNX MaxPool over the axes from 2 onward; returns (pooled values, indices).

  Padding uses -inf so pad cells never win the max. Indices are positions in the flattened X.
  NOTE(review): each index is found by comparing the pooled value with EVERY element of X and summing the
  matching positions, so a value that occurs more than once in X yields a wrong index. The comparison also
  materializes a (ret.numel(), X.numel()) grid, which is quadratic in size.
  NOTE(review): storage_order=1 transposes the indices tensor instead of computing column-major index
  values — confirm against the ONNX spec."""
  pixel_axes = tuple(range(2, X.ndim))
  ret = _padded(X, pads, auto_pad, constant_value=-math.inf, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations, ceil_mode=ceil_mode)
  ret = ret.max_pool2d(kernel_shape, stride=strides, dilation=dilations).cast(X.dtype)
  ret_len, X_len = ret.numel(), X.numel()
  # (ret_len, X_len) equality grid, weighted by flat position and summed per pooled element
  indices = ((ret.flatten().unsqueeze(1).expand(ret_len, X_len) == X.flatten().unsqueeze(0).expand(ret_len, X_len)) * \
             Tensor.arange(X_len, dtype=dtypes.int64).unsqueeze(0).expand(ret_len, X_len)).sum(1).reshape(ret.shape)
  if storage_order: indices = indices.transpose(-2, -1)
  return ret, indices
  241. def MaxUnpool(xT: Tensor, xI: Tensor, outshape: Optional[Tensor]=None, kernel_shape=None, pads=None, strides=None):
  242. out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)]
  243. outlength = prod(out_sh)
  244. xI = xI.flatten().unsqueeze(1).expand(None, outlength)
  245. arange = Tensor.arange(outlength, requires_grad=False).reshape(1, outlength).expand(xI.shape)
  246. xT = xT.flatten().unsqueeze(1).expand(None, outlength)
  247. ret = ((xI == arange) * xT).sum(0).reshape([1, 1] + out_sh)
  248. if outshape is not None and (outshape := to_python_const(outshape)) != ret.shape:
  249. diff = [outshape[2] - ret.shape[2], outshape[3] - ret.shape[3]]
  250. pad_args = [diff[0]//2, diff[1]//2, diff[0]-diff[0]//2, diff[1]-diff[1]//2]
  251. ret = ret.pad2d((pad_args[1], pad_args[3], pad_args[0], pad_args[2]))
  252. return ret
  253. def Conv(X: Tensor, W: Tensor, B:Optional[Tensor]=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, strides=1):
  254. if auto_pad != "NOTSET":
  255. padding = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
  256. else:
  257. # reorder padding
  258. padding = [p for ps in zip(pads[:len(pads)//2][::-1], pads[len(pads)//2:][::-1]) for p in ps] if pads is not None else 0
  259. return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=padding)
def ConvTranspose(X: Tensor, W: Tensor, B:Optional[Tensor]=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, output_shape=None, output_padding=0, strides=1):
  """ONNX ConvTranspose, lowered to Tensor.conv_transpose2d.

  Int strides/dilations/output_padding are expanded to one value per spatial dim of W.
  When output_shape or auto_pad is given, output_padding is re-derived as (requested output_shape - natural size).
  NOTE(review): pads are passed to conv_transpose2d in ONNX layout, without the reorder Conv applies before
  conv2d — confirm conv_transpose2d's expected layout.
  NOTE(review): in the size formulas, pads are only subtracted for n >= 2 (indices n-2, n-1), so they never
  apply for 2-D kernels — confirm intended."""
  if kernel_shape is None: kernel_shape = W.shape[2:]
  if isinstance(strides, int): strides = [strides]*(W.ndim-2)
  if isinstance(dilations, int): dilations = [dilations]*(W.ndim-2)
  if isinstance(output_padding, int): output_padding = [output_padding]*(W.ndim-2)
  # natural output size per spatial dim (stride*(in-1) + dilated kernel); only computed when needed below
  out_sh = [st*(xs-1) + (ks-1)*di+1 if n < 2 else st*(xs-1) + (ks-1)*di+1 - pads[n-2] - pads[n-1] for n, (st, xs, ks, di) in enumerate(zip(strides, X.shape[2:], kernel_shape, dilations))] if output_shape is not None or auto_pad != "NOTSET" else []
  if pads is None:
    if output_shape is None: output_shape = [xs*st for xs, st in zip(X.shape[2:], strides)]
    if auto_pad == "NOTSET": pads = [0,0] * (X.ndim - 2)
    else:
      # SAME_*: total padding needed to reach output_shape, split begin/end (extra unit at end for SAME_UPPER)
      total_padding = [st*(ish-1) + pad + ((ks-1)*dil+1)-osh for st, ish, pad, ks, dil, osh in zip(strides, X.shape[2:], output_padding, kernel_shape, dilations, output_shape)]
      pad_shape = flatten([[sh//2, sh-sh//2] for sh in total_padding])
      pads = pad_shape[::2] + pad_shape[1::2] if auto_pad == "SAME_UPPER" else pad_shape[1::2] + pad_shape[::2]
  else:
    if output_shape is None: output_shape = [st*(xs-1) + (ks-1)*di+1 if n < 2 else st*(xs-1) + (ks-1)*di+1 - pads[n-2] - pads[n-1] for n, (st, xs, ks, di) in enumerate(zip(strides, X.shape[2:], kernel_shape, dilations))]
  if out_sh: output_padding = [os - rs for os, rs in zip(output_shape, out_sh)]
  return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads if pads is not None else 0, output_padding=output_padding)
  277. def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"):
  278. b, c, h, w = X.shape
  279. if mode == "DCR":
  280. return X.reshape(b, blocksize, blocksize, c // (blocksize**2), h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, c // (blocksize**2), h * blocksize, w * blocksize)
  281. elif mode == "CRD":
  282. return X.reshape(b, c // (blocksize ** 2), blocksize, blocksize, h, w).permute(0, 1, 4, 2, 5, 3).reshape(b, c // (blocksize ** 2), h * blocksize, w * blocksize)
  283. def SpaceToDepth(X:Tensor, blocksize:int):
  284. b, c, h, w = X.shape
  285. return X.reshape(b, c, h // blocksize, blocksize, w // blocksize, blocksize).permute(0, 3, 5, 1, 2, 4).reshape(b, c * (blocksize**2), h // blocksize, w // blocksize)
  286. # Reimplemented here because you need legacy RNG for passing ONNX tests.
  287. def Dropout(data: Tensor, ratio=0.5, training_mode=False, seed=None):
  288. if isinstance(ratio, Tensor) and not ratio.shape: ratio = to_python_const(ratio) # ratio and tensor is passed in as Tensor with shape: ()
  289. if isinstance(training_mode, Tensor) and not training_mode.shape: training_mode = to_python_const(training_mode)
  290. if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's.
  291. rng = np.random.RandomState(seed)
  292. if isinstance(ratio, Tensor): ratio = ratio.item()
  293. mask = Tensor(rng.random(data.shape) >= ratio, requires_grad=False, device=data.device)
  294. return data * mask * (1/(1.0 - ratio)), mask
  295. def LRN(x: Tensor, size, alpha=1e-4, beta=0.75, bias=1.0):
  296. bs, c, iy, ix = x.shape
  297. return x / x.mul(x).reshape(bs,1,c,iy*ix).pad2d((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1).reshape(bs,c,iy,ix).mul(alpha).add(bias).pow(beta)
  298. def MeanVarianceNormalization(x: Tensor, axis=(0, 2, 3)):
  299. mean = x.mean(axis, keepdim=True)
  300. std = x.std(axis, keepdim=True, correction=0)
  301. return (x - mean) / (std + 1e-9)
def NegativeLogLikelihoodLoss(x: Tensor, target: Tensor, weight=None, ignore_index=None, reduction="mean"):
  """ONNX NegativeLogLikelihoodLoss: loss[n, d] = -x[n, target[n, d], d] * w, with w = weight[target[n, d]] (or 1).

  x: (N, C, ...) (presumably log-probabilities); target: (N, ...) class indices. Inputs whose x is not 3-D are
  flattened to (N, C, -1) / (N, -1), and the unreduced loss is reshaped back to target's shape.
  ignore_index: elements with this target get weight 0.
  reduction: "mean" (plain mean without weights, else loss.sum() / weight.sum()), "sum", or anything else
  for the per-element loss."""
  N, C, i_shape = x.shape[0], x.shape[1], x.shape
  t_shape = target.shape
  if len(x.shape) != 3:
    x = x.reshape((N, C, -1))
    target = target.reshape((N, -1))
  if weight is not None:
    # gather weight[target] with a one-hot comparison against arange(C)
    mask = target.unsqueeze(-1) == Tensor.arange(C).repeat((N, 1, 1))
    weight = (mask * weight).sum(axis=-1)
  if ignore_index is not None:
    cond = target == ignore_index
    weight = cond.where(0, weight) if weight is not None else cond.where(0, 1)
  # one-hot of target along the class axis; selects x[n, target[n, d], d]
  mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(x.ndim -2))
  loss = -(mask * x).sum(axis=1) * (1 if weight is None else weight)
  if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum()
  if reduction == "sum": return loss.sum()
  return loss.reshape(t_shape) if len(i_shape) != 3 else loss
  319. def SoftmaxCrossEntropyLoss(scores: Tensor, labels: Tensor, weights=None, ignore_index=None, reduction="mean"):
  320. _N, C, *s_dimensions = scores.shape
  321. if ignore_index is not None: labels = (labels == ignore_index).where(C+1, labels)
  322. mask = labels.unsqueeze(1) == Tensor.arange(C).reshape(1, C, *[1]*len(s_dimensions))
  323. y = scores.log_softmax(axis=1)
  324. loss = (mask * -y).sum(1)
  325. if weights is not None:
  326. weights = weights[labels, ...]
  327. loss = loss * weights
  328. if reduction == "mean": loss = loss.sum() / ((loss != 0).sum() if weights is None else weights.sum())
  329. elif reduction == "sum": loss = loss.sum()
  330. return loss, y
  331. def ArrayFeatureExtractor(x: Tensor, indices: Tensor): return x[..., indices]
  332. def Gather(x: Tensor, indices: Tensor, axis=0):
  333. if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices
  334. x_sh = list(x.shape)
  335. ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
  336. if indices.ndim > 1: indices = indices.flatten()
  337. indices = [to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in to_python_const(indices)]
  338. args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices]
  339. return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
  340. # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot
  341. return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]
  342. def GatherElements(x: Tensor, indices: Tensor, axis):
  343. indices = (indices < 0).where(x.shape[axis], 0) + indices
  344. return x.gather(axis, indices)
  345. # TODO clean this up, it's taking the longest in CI
  346. def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None, coordinate_transformation_mode='half_pixel',
  347. cubic_coeff_a=-0.75, exclude_outside=0, extrapolation_value=0.0, keep_aspect_ratio_policy='stretch',
  348. mode='nearest', nearest_mode='round_prefer_floor'):
  349. def _nearest_gather(X: Tensor, x_out, y_out): return X[:,:,y_out,:][:,:,:,x_out]
  350. def _nearest_mode(x_resized: Tensor, nearest_mode: str, x_len):
  351. if nearest_mode == "round_prefer_floor": ret = (x_resized - 0.5).ceil()
  352. elif nearest_mode == "round_prefer_ceil": ret = (x_resized + 0.5).floor()
  353. elif nearest_mode == "floor": ret = x_resized.floor()
  354. elif nearest_mode == "ceil": ret = x_resized.ceil()
  355. return ret.cast(dtypes.int32).clip(0, x_len-1)
  356. def _coordinate_transformation(x_out, y_out, output_shape, scales_, roi=None):
  357. if coordinate_transformation_mode == "half_pixel":
  358. x_out = (x_out + 0.5) / scales_[-1] - 0.5
  359. y_out = (y_out + 0.5) / scales_[-2] - 0.5
  360. elif coordinate_transformation_mode == "align_corners":
  361. x_out = x_out * (X.shape[-1] - 1) / (output_shape[-1] - 1)
  362. y_out = y_out * (X.shape[-2] - 1) / (output_shape[-2] - 1)
  363. elif coordinate_transformation_mode == "asymmetric":
  364. x_out = x_out / scales_[-1]
  365. y_out = y_out / scales_[-2]
  366. elif coordinate_transformation_mode == "half_pixel_symmetric":
  367. x_out = X.shape[-1] / 2 * (1 - int(output_shape[-1]) / output_shape[-1]) + (x_out + 0.5) / scales_[-1] - 0.5
  368. y_out = X.shape[-2] / 2 * (1 - int(output_shape[-2]) / output_shape[-2]) + (y_out + 0.5) / scales_[-2] - 0.5
  369. elif coordinate_transformation_mode == "pytorch_half_pixel":
  370. x_out = (x_out + 0.5) / scales_[-1] - 0.5 if output_shape[-1] > 1 else Tensor([0])
  371. y_out = (y_out + 0.5) / scales_[-2] - 0.5 if output_shape[-2] > 1 else Tensor([0])
  372. elif coordinate_transformation_mode == "tf_crop_and_resize":
  373. x_out = roi[-1][0] * (X.shape[-1] - 1) + x_out * ((roi[-1][1] - roi[-1][0]) * (X.shape[-1] - 1) / (output_shape[-1] - 1)) if output_shape[-1] > 1 else Tensor([0.5 * (roi[-1][0] + roi[-1][1]) * (X.shape[-1] - 1)])
  374. y_out = roi[-2][0] * (X.shape[-2] - 1) + y_out * ((roi[-2][1] - roi[-2][0]) * (X.shape[-2] - 1) / (output_shape[-2] - 1)) if output_shape[-2] > 1 else Tensor([0.5 * (roi[-2][0] + roi[-2][1]) * (X.shape[-2] - 1)])
  375. return x_out.clip(0, X.shape[-1]-1), y_out.clip(0, X.shape[-2]-1)
  376. if roi is not None:
  377. roi = to_python_const(roi)
  378. roi = [(st,ed) for st, ed in zip(roi[:len(roi)//2], roi[len(roi)//2:])]
  379. roi_ = [(1,1)] * 4
  380. if axes is not None:
  381. for a,r in zip(axes, roi):
  382. roi_[a] = r
  383. roi = roi_
  384. if scales is not None:
  385. scales = to_python_const(scales)
  386. if axes is not None:
  387. scales_ = [1]*X.ndim
  388. for a,s in zip(axes, scales):
  389. scales_[a] = s
  390. scales = scales_
  391. elif sizes is not None:
  392. sizes = to_python_const(sizes)
  393. scales = []
  394. if axes is not None:
  395. sizes_ = [1]*X.ndim
  396. for a,s in zip(axes, sizes):
  397. sizes_[a] = s
  398. scales.append(s/X.shape[a])
  399. sizes = sizes_
  400. else: scales = [si/xs for xs, si in zip(X.shape, sizes)]
  401. if keep_aspect_ratio_policy == "not_larger":
  402. scale = min(scales)
  403. sizes = list(X.shape[:-2]) + [math.ceil(sh*scale) for sh in X.shape[-2:]]
  404. elif keep_aspect_ratio_policy == "not_smaller":
  405. scale = max(scales)
  406. sizes = list(X.shape[:-2]) + [math.ceil(sh*scale) for sh in X.shape[-2:]]
  407. output_shape = sizes if sizes else [math.floor(x*s) for x,s in zip(X.shape, scales)]
  408. output_shape_ = sizes if sizes else [x*s for x,s in zip(X.shape, scales)]
  409. scales_ = [os/xs for xs, os in zip(X.shape, output_shape)]
  410. x_out = Tensor.arange(output_shape[-1], dtype=dtypes.default_float)
  411. y_out = Tensor.arange(output_shape[-2], dtype=dtypes.default_float)
  412. if mode == "nearest":
  413. x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape, scales_, roi)
  414. x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1])
  415. y_out = _nearest_mode(y_out, nearest_mode, X.shape[-1])
  416. return _nearest_gather(X, x_out, y_out)
  417. if mode == "linear":
  418. x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape_, scales, roi)
  419. ret = []
  420. for y in to_python_const(y_out):
  421. for x in to_python_const(x_out):
  422. x_floor, y_floor = int(x), int(y)
  423. y_shrink = (y_floor, math.ceil(y)+1)
  424. x_shrink = (x_floor, math.ceil(x)+1)
  425. corners = to_python_const(X.shrink((None, None, y_shrink, x_shrink)))[0][0]
  426. wx, wy = math.ceil(x) - x, math.ceil(y) - y
  427. if x == x_floor and y == y_floor:
  428. weighted = corners[0][0]
  429. elif x == x_floor:
  430. weighted = corners[0][0] * wy + corners[1][0] * (1-wy)
  431. elif y == y_floor:
  432. weighted = corners[0][0] * wx + corners[0][1] * (1-wx)
  433. else:
  434. weighted = (corners[0][0] * wx + corners[0][1] * (1-wx)) * wy + \
  435. (corners[1][0] * (wx) + corners[1][1] * (1-wx)) * (1-wy)
  436. ret.append(weighted)
  437. return Tensor(ret).reshape(output_shape)
  438. if mode == "cubic":
  439. raise NotImplementedError("cubic interpolation is not implemented")
  440. def CenterCropPad(t: Tensor, shape: Tensor, axes=None):
  441. if not axes: axes = list(range(t.ndim))
  442. shrink_arg = [None] * t.ndim
  443. pad_arg = [None] * t.ndim
  444. shape = to_python_const(shape)
  445. for s, x in zip(shape, axes):
  446. tx = t.shape[x]
  447. if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2)
  448. elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2)
  449. return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))
  450. def OneHot(indices: Tensor, depth: Tensor, values: Tensor, axis=-1):
  451. depth = int(to_python_const(depth))
  452. indices, rank = (indices < 0).where(indices+depth, indices), indices.ndim
  453. if axis < 0: axis += rank + 1
  454. ls, rs = indices.shape[0:axis], indices.shape[axis: rank]
  455. cond = indices[:,None] == Tensor.arange(depth).reshape((1,) * len(ls) + (depth,) + (1,) * len(rs))
  456. return cond.where(values[1], values[0])
  457. def Erf(x: Tensor):
  458. t = 1.0 / (1.0 + 0.3275911 * x.abs())
  459. term1 = 0.254829592 * t
  460. term2 = -0.284496736 * t ** 2
  461. term3 = 1.421413741 * t ** 3
  462. term4 = -1.453152027 * t ** 4
  463. term5 = 1.061405429 * t ** 5
  464. y = (term1 + term2 + term3 + term4 + term5)
  465. z = 1.0 - y * (-x * x).exp()
  466. return (x > 0).where(z, -z)
  467. def Compress(inp: Tensor, condition: Tensor, axis=None):
  468. if axis is None:
  469. inp = inp.flatten()
  470. axis = 0
  471. if axis < 0: axis += inp.ndim
  472. con_np = to_python_const(condition)
  473. con = Tensor(np.arange(condition.shape[0])[con_np]) # no boolean indexing in Tensor
  474. return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]
  475. def EyeLike(x: Tensor, dtype=None, k=0):
  476. if dtype is None: dtype = x.dtype
  477. else: dtype = DTYPE_MAP[int(dtype)]
  478. dim = min(x.shape)
  479. if x.shape[0] == x.shape[1]:
  480. return Tensor.eye(dim, dtype=dtype)
  481. padarg = tuple(None if d == dim else (k, d-dim-k) for d in x.shape)
  482. return Tensor.eye(dim, dtype=dtype).pad(padarg)
  483. def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode)
  484. def IsInf(x: Tensor, detect_negative=1, detect_positive=1):
  485. return (x == float("inf")) * bool(detect_positive) + (x == float("-inf")) * bool(detect_negative)
  486. def DequantizeLinear(x: Tensor, x_scale: Tensor, x_zero_point: Union[Tensor, int] = 0, axis=1, block_size=0):
  487. def numpy_repeat(t: Tensor, axis, repeats, out_shape):
  488. t = t.reshape(tuple(-1 if i == axis-1 else 1 if i == axis else sh for i,sh in enumerate(t.shape)))
  489. return t.repeat([repeats if i == axis else 1 for i in range(t.ndim)]).reshape(out_shape)
  490. if axis < 0: axis += x.ndim
  491. if block_size:
  492. x_zer, x_sc = numpy_repeat(x_zero_point, axis, block_size, x.shape), numpy_repeat(x_scale, axis, block_size, x.shape)
  493. else:
  494. x_sc = x_scale.reshape(*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim))
  495. x_zer = x_zero_point.reshape(*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim)) if isinstance(x_zero_point, Tensor) else x_zero_point
  496. return ((x.float() - x_zer) * x_sc).cast(x_scale.dtype)
  497. def IsNaN(x: Tensor): return x != x
# based on the ONNX reference implementation: https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_image_decoder.py
# Pillow does the actual decoding; without it we would have to hand-write decoders for PNG, JPEG, WebP, etc.
  500. def ImageDecoder(encoded_stream: Tensor, pixel_format="RGB"):
  501. try:
  502. import PIL.Image
  503. except ImportError as e:
  504. raise ImportError("Pillow must be installed to use the reference implementation of the ImageDecoder operator") from e
  505. img = PIL.Image.open(io.BytesIO(to_python_const(encoded_stream, tobytes=True)))
  506. if pixel_format == "BGR":
  507. return Tensor(np.array(img))[:, :, ::-1]
  508. if pixel_format == "RGB":
  509. return Tensor(np.array(img))
  510. if pixel_format == "Grayscale":
  511. img = img.convert("L")
  512. decoded = Tensor(np.array(img))
  513. return decoded.unsqueeze(-1) # (H, W) to (H, W, 1)
  514. raise ValueError(f"pixel_format={pixel_format!r} is not supported.")
  515. def AffineGrid(theta: Tensor, size: Tensor, align_corners=0):
  516. _, _, *data_sz = to_python_const(size)
  517. size_zeros, original_grid = Tensor.zeros(data_sz), Tensor.ones(data_sz)
  518. stackable = [original_grid]
  519. for dim, dim_sz in enumerate(data_sz):
  520. a = Tensor.arange(-1, 1.0001, 2/(dim_sz-1)) if align_corners == 1 else Tensor.arange(-1+1/dim_sz, 1, 2/dim_sz)
  521. if dim == 0: stackable = [a.reshape(dim_sz, *[1]*(len(data_sz)-1)) + size_zeros, *stackable]
  522. elif dim == 1: stackable = [a.reshape(1, dim_sz, *[1]*(len(data_sz)-2)) + size_zeros, *stackable]
  523. else: stackable = [a.reshape(1, dim_sz) + size_zeros, *stackable]
  524. original_grid = Tensor.stack(*stackable, dim=len(data_sz))
  525. if original_grid.ndim == 3:
  526. N, dim_2d, dim_homo = theta.shape
  527. assert dim_2d == 2 and dim_homo == 3
  528. H, W, dim_homo = original_grid.shape
  529. assert dim_homo == 3
  530. original_grid = original_grid.reshape(H*W, dim_homo).transpose()
  531. return theta.matmul(original_grid).permute(0,2,1).reshape(N, H, W, dim_2d)
  532. assert original_grid.ndim == 4
  533. N, dim_3d, dim_homo = theta.shape
  534. assert dim_3d == 3 and dim_homo == 4
  535. D, H, W, dim_homo = original_grid.shape
  536. assert dim_homo == 4
  537. original_grid = original_grid.reshape(D*H*W, dim_homo).transpose()
  538. return theta.matmul(original_grid).permute(0,2,1).reshape(N, D, H, W, dim_3d)
  539. # **************** com.microsoft Ops ****************
  540. def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma, beta:Optional[Tensor]=None, bias:Optional[Tensor]=None, epsilon=None):
  541. if epsilon is None: epsilon=1e-12
  542. x = x + skip + bias
  543. return x.layernorm(eps=epsilon) * gamma + beta, None, None, x
  544. def FastGelu(x:Tensor, bias:Optional[Tensor]=None):
  545. # this is tanh approamixated
  546. return (x + bias).gelu()
  547. def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None, word_embedding:Tensor=None, position_embedding:Tensor=None, segment_embedding:Optional[Tensor]=None, gamma=None, beta=None, mask:Optional[Tensor]=None, position_ids:Optional[Tensor]=None, epsilon=None, mask_index_type=None):
  548. # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
  549. assert (segment_ids is None) is (segment_embedding is None)
  550. assert (mask is None) is (mask_index_type is None)
  551. assert mask is None, "functionality not supported yet" # TODO
  552. input_shape = input_ids.shape
  553. seq_length = input_shape[1]
  554. compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
  555. vocab_size, max_position_embeddings, type_vocab_size = word_embedding.shape[0], position_embedding.shape[0], (segment_embedding.shape[0] if compute_seg_emb else None)
  556. def embedding(x:Tensor, vocab_size, weight:Tensor) -> Tensor: # TODO from nn.Embedding. Could probably upstream this to Tensor
  557. vocab_counter = Tensor.arange(vocab_size, dtype=x.dtype, requires_grad=False).reshape(1, 1, vocab_size).expand(*x.shape, vocab_size)
  558. return (vocab_counter == x.unsqueeze(2).expand(*x.shape, vocab_size)) @ weight
  559. # bert embedding layer
  560. if epsilon is None: epsilon = 1e-12
  561. if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
  562. wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
  563. pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
  564. seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None
  565. embedding_sum = wrd_embedding_res + pos_embedding_res
  566. if seg_embedding_res is not None: embedding_sum = embedding_sum + seg_embedding_res
  567. out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
  568. return out, None, embedding_sum
  569. def Attention(x:Tensor, weights, bias:Optional[Tensor]=None, mask_index:Optional[Tensor]=None, past:Optional[Tensor]=None, relative_position_bias:Optional[Tensor]=None, past_sequence_length:Optional[Tensor]=None, do_rotary=None, mask_filter_value=None, num_heads=None, past_present_share_buffer=None, qkv_hidden_sizes=None, scale=None, unidirectional=None):
  570. # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
  571. assert num_heads is not None # required
  572. assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
  573. assert relative_position_bias==do_rotary==past_sequence_length==mask_filter_value==past_present_share_buffer==scale==None, "functionality not supported yet" # TODO strange params
  574. hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)
  575. if unidirectional: # gpt-style
  576. assert hidden_size == v_hidden_size
  577. xqkv = x.linear(weights, bias)
  578. xq, xk, xv = [xqkv._slice([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
  579. else: # bert-style
  580. wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
  581. bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size]) if bias is not None else None
  582. xq, xk, xv = [x.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
  583. xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]
  584. if past is not None:
  585. xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
  586. present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))
  587. def attn(query, key, value, attn_mask):
  588. query_length, key_length = query.shape[-2], key.shape[-2]
  589. cdim = max(query_length, key_length) + 1
  590. attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
  591. # This is where Tensor.scaled_dot_product_attention differs:
  592. causal_mask = Tensor.ones((cdim, cdim), requires_grad=False, dtype=dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length]
  593. masked = Tensor.where(causal_mask, attn_weights, -math.inf)
  594. if attn_mask is not None: masked = masked + attn_mask
  595. return masked.softmax(-1) @ value
  596. bsz, _, seq_len, _ = xq.shape
  597. out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
  598. return out, present
  599. # **************** ai.onnx.preview.training Ops ****************
  600. # TODO not entirely sure these optimizers are correct
  601. def Adagrad(R, T, *inputs, decay_factor=0.0, epsilon=0.0, norm_coefficient=0.0):
  602. groups = len(inputs) // 3
  603. grouped_inputs = [inputs[i::groups] for i in range(groups)]
  604. r = to_python_const(R / (1 + T * decay_factor))
  605. ret = []
  606. for X, G, H in grouped_inputs:
  607. X.grad = norm_coefficient * X + G
  608. X.grad.requires_grad, H.requires_grad = False, False # TODO manually turning off requires_grad, see TODO under (domain == "ai.onnx.preview.training") in onnx.py
  609. H.assign(H.detach() + X.grad * X.grad).realize()
  610. H_adaptive = H.sqrt() + epsilon
  611. X.assign(X.detach() - r * X.grad / H_adaptive)
  612. ret.extend([X, H])
  613. ret = ret[::2] + ret[1::2]
  614. return tuple(ret)
  615. def Momentum(R, T, *inputs, alpha, beta, mode, norm_coefficient):
  616. groups = len(inputs) // 3
  617. grouped_inputs = [inputs[i::groups] for i in range(groups)]
  618. T, R.requires_grad = to_python_const(T), False
  619. beta_adjusted = beta if T > 0 else 1
  620. ret = []
  621. for X, G, V in grouped_inputs:
  622. X.grad = (norm_coefficient * X + G).realize()
  623. X.grad.requires_grad, V.requires_grad = False, False
  624. V.assign(alpha * V + beta_adjusted * X.grad).realize()
  625. if mode == "standard": X.assign(X.detach() - R * V).realize()
  626. elif mode == "nesterov": X.assign(X.detach() - R * (X.grad + alpha + V)).realize()
  627. ret.extend([X, V])
  628. ret = ret[::2] + ret[1::2]
  629. return tuple(ret)
  630. # copied from tinygrad/nn/optim.py: LAMB with some edits
  631. def Adam(R, T, *inputs, alpha=0.9, beta=0.999, epsilon=0.0, norm_coefficient=0.0, norm_coefficient_post=0.0):
  632. groups = len(inputs) // 4
  633. grouped_inputs = [inputs[i::groups] for i in range(groups)]
  634. T, R.requires_grad = to_python_const(T), False
  635. ret = []
  636. for X, G, V, H in grouped_inputs:
  637. X.grad = (norm_coefficient * X + G).realize()
  638. V.requires_grad, H.requires_grad, X.grad.requires_grad = False, False, False
  639. V.assign(alpha * V + (1.0 - alpha) * X.grad).realize()
  640. H.assign(beta * H + (1.0 - beta) * (X.grad * X.grad)).realize()
  641. up = (V / (1.0 - alpha**T)) / ((H / (1.0 - beta**T)).sqrt() + epsilon) if T > 0 else V / (H.sqrt() + epsilon)
  642. X.assign(X.detach() - R * up).realize()
  643. X = (1 - norm_coefficient_post) * X
  644. ret.extend([X, V, H])
  645. ret = ret[::3] + ret[1::3] + ret[2::3]
  646. return tuple(ret)