function.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. """This is where the forwards and backwards passes live."""
  2. import math
  3. from typing import Tuple, Optional
  4. from tinygrad.helpers import argsort
  5. from tinygrad.dtype import dtypes, DType, sum_acc_dtype
  6. from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
  7. from tinygrad.tensor import Function
  8. from tinygrad.lazy import LazyBuffer
  9. from tinygrad.shape.symbolic import sint
  10. class Contiguous(Function):
  11. def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
  12. def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output
  13. class ContiguousBackward(Function):
  14. def forward(self, x:LazyBuffer) -> LazyBuffer: return x
  15. def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.contiguous()
  16. class Cast(Function):
  17. def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
  18. self.input_dtype, self.bitcast = x.dtype, bitcast
  19. return x.cast(dtype, bitcast)
  20. def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.cast(self.input_dtype, self.bitcast)
# ************* unary ops *************
  22. class Neg(Function):
  23. def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(UnaryOps.NEG)
  24. def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(UnaryOps.NEG)
  25. class Reciprocal(Function):
  26. def forward(self, x:LazyBuffer) -> LazyBuffer:
  27. self.ret = x.e(UnaryOps.RECIP)
  28. return self.ret
  29. def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
  30. return grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.ret).e(BinaryOps.MUL, self.ret)
  31. class Sin(Function):
  32. def forward(self, x:LazyBuffer) -> LazyBuffer:
  33. self.x = x
  34. return x.e(UnaryOps.SIN)
  35. def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
  36. return self.x.const(math.pi / 2).e(BinaryOps.ADD, self.x.e(UnaryOps.NEG)).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output)
# NOTE: maximum(x, 0) behaves differently where x=0; Relu's backward below yields a zero gradient there
  38. class Relu(Function):
  39. def forward(self, x:LazyBuffer) -> LazyBuffer:
  40. self.ret = x.e(BinaryOps.MAX, x.const(0))
  41. return self.ret
  42. def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
  43. return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output)
  44. class Log(Function):
  45. def forward(self, x:LazyBuffer) -> LazyBuffer:
  46. self.x = x
  47. return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
  48. def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.MUL, self.x.e(UnaryOps.RECIP))
  49. class Exp(Function):
  50. def forward(self, x:LazyBuffer) -> LazyBuffer:
  51. self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
  52. return self.ret
  53. def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.e(BinaryOps.MUL, grad_output)
  54. class Sqrt(Function):
  55. def forward(self, x:LazyBuffer) -> LazyBuffer:
  56. self.ret = x.e(UnaryOps.SQRT)
  57. return self.ret
  58. def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
  59. return grad_output.e(BinaryOps.MUL, self.ret.e(BinaryOps.MUL, self.ret.const(2)).e(UnaryOps.RECIP))
# NOTE: the implicit derivative of sigmoid is not stable
# https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
# TODO: have the backend automatically find this
  63. class Sigmoid(Function):
  64. def forward(self, x:LazyBuffer) -> LazyBuffer:
  65. self.ret = x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)).e(UnaryOps.RECIP)
  66. return self.ret
  67. def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
  68. return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.ADD, self.ret.e(UnaryOps.NEG))).e(BinaryOps.MUL, grad_output)
  69. class Sign(Function):
  70. def forward(self, x:LazyBuffer) -> LazyBuffer:
  71. return x.e(BinaryOps.CMPNE, x.const(0)).e(
  72. TernaryOps.WHERE, x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)), x.const(0))
  73. # backward always return 0 to match torch
  74. def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const(0)
# ************* binary ops *************
  76. class Less(Function):
  77. def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPLT, y)
  78. def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
  79. class Neq(Function):
  80. def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPNE, y)
  81. def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
  82. class Xor(Function):
  83. def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.XOR, y)
  84. class BitwiseAnd(Function):
  85. def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.AND, y)
  86. class BitwiseOr(Function):
  87. def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.OR, y)
  88. class Threefry(Function):
  89. def forward(self, x:LazyBuffer, seed:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.THREEFRY, seed)
  90. class Add(Function):
  91. def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.ADD, y)
  92. def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
  93. return grad_output if self.needs_input_grad[0] else None, \
  94. grad_output if self.needs_input_grad[1] else None
  95. class Mul(Function):
  96. def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
  97. self.x, self.y = x, y
  98. return x.e(BinaryOps.MUL, y)
  99. def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
  100. return self.y.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
  101. self.x.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
  102. class Div(Function):
  103. def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
  104. self.x, self.y = x, y
  105. return x.e(BinaryOps.MUL, y.e(UnaryOps.RECIP)) if not dtypes.is_int(x.dtype) else x.e(BinaryOps.IDIV, y)
  106. def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
  107. return grad_output.e(BinaryOps.MUL, self.y.e(UnaryOps.RECIP)) if self.needs_input_grad[0] else None, \
  108. grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.MUL, self.y.e(BinaryOps.MUL, self.y).e(UnaryOps.RECIP)) if self.needs_input_grad[1] else None # noqa: E501
# ************* ternary ops *************
  110. class Where(Function):
  111. def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
  112. self.x = x
  113. return self.x.e(TernaryOps.WHERE, y, z)
  114. def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
  115. return None, \
  116. self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \
  117. self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None
# ************* reduce ops *************
  119. class Sum(Function):
  120. def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
  121. self.input_shape = x.shape
  122. return x.r(ReduceOps.SUM, axis)
  123. def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
  124. class Max(Function):
  125. def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
  126. self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis
  127. return self.ret
  128. def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
  129. # 1s in locations where the max was chosen (can be two locations)
  130. max_is_1s = self.x.const(1.0).cast(dtypes.float).e(BinaryOps.ADD, self.x.e(BinaryOps.CMPNE, \
  131. self.ret.expand(self.x.shape)).cast(dtypes.float).e(UnaryOps.NEG))
  132. div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
  133. return max_is_1s.e(BinaryOps.MUL, div.e(UnaryOps.RECIP)).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
# ************* movement ops *************
# NOTE: this is sum in reverse
  136. class Expand(Function):
  137. def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
  138. self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so)
  139. return x.expand(shape)
  140. def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
  141. return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(ReduceOps.SUM, self.expanded_axis).cast(grad_output.dtype)
  142. class Reshape(Function):
  143. def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
  144. self.input_shape = x.shape
  145. return x.reshape(shape)
  146. def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.reshape(self.input_shape)
  147. class Permute(Function):
  148. def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
  149. self.input_order = order
  150. return x.permute(order)
  151. def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.permute(argsort(self.input_order))
  152. class Pad(Function):
  153. def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
  154. self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
  155. return x.pad(arg)
  156. def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.shrink(self.narg)
  157. class Shrink(Function):
  158. def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
  159. self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
  160. return x.shrink(arg)
  161. def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.pad(self.narg)
  162. class Flip(Function):
  163. def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
  164. self.arg = tuple([-1 if i in axis else 1 for i in range(len(x.shape))])
  165. return x.stride(self.arg)
  166. def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.stride(self.arg)