| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220 |
- """This is where the forwards and backwards passes live."""
- import math
- from typing import Tuple, Optional
- from tinygrad.helpers import argsort
- from tinygrad.dtype import dtypes, DType, sum_acc_dtype
- from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
- from tinygrad.tensor import Function
- from tinygrad.lazy import LazyBuffer
- from tinygrad.shape.symbolic import sint
class Contiguous(Function):
  """Forces the buffer to be realized contiguously; gradient flows through unchanged."""
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    return x.contiguous()
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output
class ContiguousBackward(Function):
  """Identity in forward; forces the *gradient* to be realized contiguously."""
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    return x
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.contiguous()
class Cast(Function):
  """Casts (or bitcasts) to a new dtype; backward casts the gradient back to the input dtype."""
  def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
    # remember how to undo the cast in backward
    self.input_dtype = x.dtype
    self.bitcast = bitcast
    return x.cast(dtype, bitcast)
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.cast(self.input_dtype, self.bitcast)
- # ************* unary ops *************
class Neg(Function):
  """Elementwise negation; d(-x)/dx = -1, so the gradient is negated as well."""
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    return x.e(UnaryOps.NEG)
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.e(UnaryOps.NEG)
class Reciprocal(Function):
  """1/x. Saves the output since d(1/x)/dx = -1/x**2 = -(ret*ret)."""
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.e(UnaryOps.RECIP)
    return self.ret
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    neg_grad = grad_output.e(UnaryOps.NEG)
    return neg_grad.e(BinaryOps.MUL, self.ret).e(BinaryOps.MUL, self.ret)
class Sin(Function):
  """sin(x). Backward computes cos(x) as sin(pi/2 - x) since SIN is the only trig primitive."""
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.x = x
    return x.e(UnaryOps.SIN)
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    cos_x = self.x.const(math.pi / 2).e(BinaryOps.ADD, self.x.e(UnaryOps.NEG)).e(UnaryOps.SIN)
    return cos_x.e(BinaryOps.MUL, grad_output)
# NOTE: maximum(x, 0) behaves differently where x=0
class Relu(Function):
  """max(x, 0). Backward passes the gradient only where the output is positive."""
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.e(BinaryOps.MAX, x.const(0))
    return self.ret
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    # mask is 1 where ret > 0, else 0
    mask = self.ret.const(0).e(BinaryOps.CMPLT, self.ret).cast(grad_output.dtype)
    return mask.e(BinaryOps.MUL, grad_output)
class Log(Function):
  """Natural log built from the LOG2 primitive: ln(x) = log2(x) * ln(2). d(ln x)/dx = 1/x."""
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.x = x
    return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.e(BinaryOps.MUL, self.x.e(UnaryOps.RECIP))
class Exp(Function):
  """e**x built from the EXP2 primitive: exp(x) = 2**(x/ln 2). The derivative is the saved output."""
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
    return self.ret
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return self.ret.e(BinaryOps.MUL, grad_output)
class Sqrt(Function):
  """sqrt(x). Saves the output since d(sqrt x)/dx = 1/(2*sqrt(x)) = 1/(2*ret)."""
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.e(UnaryOps.SQRT)
    return self.ret
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    two_ret = self.ret.e(BinaryOps.MUL, self.ret.const(2))
    return grad_output.e(BinaryOps.MUL, two_ret.e(UnaryOps.RECIP))
# NOTE: the implicit derivative of sigmoid is not stable
# https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
# TODO: have the backend automatically find this
class Sigmoid(Function):
  """1 / (1 + exp(-x)), with exp built from EXP2. Backward uses the stable form ret*(1-ret)."""
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    exp_neg_x = x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)
    self.ret = x.const(1).e(BinaryOps.ADD, exp_neg_x).e(UnaryOps.RECIP)
    return self.ret
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    one_minus_ret = self.ret.const(1).e(BinaryOps.ADD, self.ret.e(UnaryOps.NEG))
    return self.ret.e(BinaryOps.MUL, one_minus_ret).e(BinaryOps.MUL, grad_output)
class Sign(Function):
  """-1 for negative x, 1 for positive x, 0 for x == 0."""
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    is_nonzero = x.e(BinaryOps.CMPNE, x.const(0))
    neg_or_pos = x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1))
    return is_nonzero.e(TernaryOps.WHERE, neg_or_pos, x.const(0))
  # backward always return 0 to match torch
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.const(0)
- # ************* binary ops *************
class Less(Function):
  """x < y elementwise. Comparison is non-differentiable, so both gradients are None."""
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    return x.e(BinaryOps.CMPLT, y)
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    return None, None
class Neq(Function):
  """x != y elementwise. Comparison is non-differentiable, so both gradients are None."""
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    return x.e(BinaryOps.CMPNE, y)
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    return None, None
class Xor(Function):
  """Bitwise x ^ y. No backward: not differentiable."""
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    return x.e(BinaryOps.XOR, y)
class BitwiseAnd(Function):
  """Bitwise x & y. No backward: not differentiable."""
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    return x.e(BinaryOps.AND, y)
class BitwiseOr(Function):
  """Bitwise x | y. No backward: not differentiable."""
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    return x.e(BinaryOps.OR, y)
class Threefry(Function):
  """Threefry counter-based RNG primitive over (x, seed). No backward: not differentiable."""
  def forward(self, x:LazyBuffer, seed:LazyBuffer) -> LazyBuffer:
    return x.e(BinaryOps.THREEFRY, seed)
class Add(Function):
  """x + y. The gradient of a sum passes through unchanged to each input that needs it."""
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    return x.e(BinaryOps.ADD, y)
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    grad_x = grad_output if self.needs_input_grad[0] else None
    grad_y = grad_output if self.needs_input_grad[1] else None
    return grad_x, grad_y
class Mul(Function):
  """x * y. Each input's gradient is the incoming gradient times the *other* input."""
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    self.x, self.y = x, y
    return x.e(BinaryOps.MUL, y)
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    grad_x = self.y.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None
    grad_y = self.x.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
    return grad_x, grad_y
class Div(Function):
  """x / y: integer dtypes use IDIV, everything else uses x * (1/y)."""
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    self.x, self.y = x, y
    if dtypes.is_int(x.dtype): return x.e(BinaryOps.IDIV, y)
    return x.e(BinaryOps.MUL, y.e(UnaryOps.RECIP))
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    grad_x = grad_output.e(BinaryOps.MUL, self.y.e(UnaryOps.RECIP)) if self.needs_input_grad[0] else None
    grad_y = None
    if self.needs_input_grad[1]:
      # d(x/y)/dy = -x / y**2
      y_squared_recip = self.y.e(BinaryOps.MUL, self.y).e(UnaryOps.RECIP)
      grad_y = grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.MUL, y_squared_recip)
    return grad_x, grad_y
- # ************* ternary ops *************
class Where(Function):
  """Ternary select: x ? y : z. The condition gets no gradient; y and z get the gradient
  masked to the positions where they were selected."""
  def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
    self.x = x
    return self.x.e(TernaryOps.WHERE, y, z)
  def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
    zero = grad_output.const(0)
    grad_y = self.x.e(TernaryOps.WHERE, grad_output, zero) if self.needs_input_grad[1] else None
    grad_z = self.x.e(TernaryOps.WHERE, zero, grad_output) if self.needs_input_grad[2] else None
    return None, grad_y, grad_z
- # ************* reduce ops *************
class Sum(Function):
  """Sum-reduce along axis. Backward broadcasts the gradient back to the input shape."""
  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
    self.input_shape = x.shape
    return x.r(ReduceOps.SUM, axis)
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.expand(self.input_shape)
class Max(Function):
  """Max-reduce along axis. The gradient is split evenly among all positions that tie for max."""
  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
    self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis
    return self.ret
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    # 1s in locations where the max was chosen (can be two locations)
    not_max = self.x.e(BinaryOps.CMPNE, self.ret.expand(self.x.shape)).cast(dtypes.float)
    max_is_1s = self.x.const(1.0).cast(dtypes.float).e(BinaryOps.ADD, not_max.e(UnaryOps.NEG))
    # count of tied maxima per reduced group, broadcast back to the input shape
    div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
    normalized = max_is_1s.e(BinaryOps.MUL, div.e(UnaryOps.RECIP)).cast(grad_output.dtype)
    return normalized.e(BinaryOps.MUL, grad_output.expand(self.x.shape))
- # ************* movement ops *************
# NOTE: this is sum in reverse
class Expand(Function):
  """Broadcast to a larger shape. Backward sums the gradient over the broadcast axes,
  accumulating in a wider dtype before casting back."""
  def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
    self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so)
    return x.expand(shape)
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    acc = grad_output.cast(sum_acc_dtype(grad_output.dtype))
    return acc.r(ReduceOps.SUM, self.expanded_axis).cast(grad_output.dtype)
class Reshape(Function):
  """Reshape; backward reshapes the gradient back to the original shape."""
  def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
    self.input_shape = x.shape
    return x.reshape(shape)
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.reshape(self.input_shape)
class Permute(Function):
  """Axis permutation; backward applies the inverse permutation (argsort of the order)."""
  def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
    self.input_order = order
    return x.permute(order)
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.permute(argsort(self.input_order))
class Pad(Function):
  """Zero-pad each dim by (before, after). Backward shrinks the padding back off."""
  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
    # per-dim (start, end) window that recovers the original data from the padded buffer
    self.narg = tuple((p[0], s+p[0]) for s, p in zip(x.shape, arg))
    return x.pad(arg)
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.shrink(self.narg)
class Shrink(Function):
  """Slice each dim to (start, end). Backward zero-pads the gradient back to the input shape."""
  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
    # per-dim (before, after) padding that undoes the shrink
    self.narg = tuple((p[0], s-p[1]) for s, p in zip(x.shape, arg))
    return x.shrink(arg)
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.pad(self.narg)
class Flip(Function):
  """Reverse the listed axes via -1 strides. Flipping is its own inverse, so backward
  reapplies the same stride argument to the gradient."""
  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
    self.arg = tuple(-1 if i in axis else 1 for i in range(len(x.shape)))
    return x.stride(self.arg)
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.stride(self.arg)
|