"""This is where the forwards and backwards passes live.""" import math from typing import Tuple, Optional from tinygrad.helpers import argsort from tinygrad.dtype import dtypes, DType, sum_acc_dtype from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps from tinygrad.tensor import Function from tinygrad.lazy import LazyBuffer from tinygrad.shape.symbolic import sint class Contiguous(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous() def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output class ContiguousBackward(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: return x def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.contiguous() class Cast(Function): def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer: self.input_dtype, self.bitcast = x.dtype, bitcast return x.cast(dtype, bitcast) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.cast(self.input_dtype, self.bitcast) # ************* unary ops ************* class Neg(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(UnaryOps.NEG) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(UnaryOps.NEG) class Reciprocal(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: self.ret = x.e(UnaryOps.RECIP) return self.ret def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.ret).e(BinaryOps.MUL, self.ret) class Sin(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: self.x = x return x.e(UnaryOps.SIN) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.x.const(math.pi / 2).e(BinaryOps.ADD, self.x.e(UnaryOps.NEG)).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output) # NOTE: maximum(x, 0) behaves differently where x=0 class Relu(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: self.ret = x.e(BinaryOps.MAX, x.const(0)) return self.ret def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output) class Log(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: self.x = x return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2))) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.MUL, self.x.e(UnaryOps.RECIP)) class Exp(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2) return self.ret def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.e(BinaryOps.MUL, grad_output) class Sqrt(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: self.ret = x.e(UnaryOps.SQRT) return self.ret def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.MUL, self.ret.e(BinaryOps.MUL, self.ret.const(2)).e(UnaryOps.RECIP)) # NOTE: the implicit derivative of sigmoid is not stable # https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e # TODO: have the backend automatically find this class Sigmoid(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: self.ret = x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)).e(UnaryOps.RECIP) return self.ret def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.ADD, self.ret.e(UnaryOps.NEG))).e(BinaryOps.MUL, grad_output) class Sign(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPNE, x.const(0)).e( TernaryOps.WHERE, x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)), x.const(0)) # backward always return 0 to match torch def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const(0) # ************* binary ops ************* class Less(Function): def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPLT, y) def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None class Neq(Function): def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPNE, y) def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None class Xor(Function): def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.XOR, y) class BitwiseAnd(Function): def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.AND, y) class BitwiseOr(Function): def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.OR, y) class Threefry(Function): def forward(self, x:LazyBuffer, seed:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.THREEFRY, seed) class Add(Function): def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.ADD, y) def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return grad_output if self.needs_input_grad[0] else None, \ grad_output if self.needs_input_grad[1] else None class Mul(Function): def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: self.x, self.y = x, y return x.e(BinaryOps.MUL, y) def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return self.y.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \ self.x.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None class Div(Function): def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: self.x, self.y = x, y return x.e(BinaryOps.MUL, y.e(UnaryOps.RECIP)) if not dtypes.is_int(x.dtype) else x.e(BinaryOps.IDIV, y) def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return grad_output.e(BinaryOps.MUL, self.y.e(UnaryOps.RECIP)) if self.needs_input_grad[0] else None, \ grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.MUL, self.y.e(BinaryOps.MUL, self.y).e(UnaryOps.RECIP)) if self.needs_input_grad[1] else None # noqa: E501 # ************* ternary ops ************* class Where(Function): def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer: self.x = x return self.x.e(TernaryOps.WHERE, y, z) def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]: return None, \ self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \ self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None # ************* reduce ops ************* class Sum(Function): def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer: self.input_shape = x.shape return x.r(ReduceOps.SUM, axis) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape) class Max(Function): def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer: self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis return self.ret def backward(self, grad_output:LazyBuffer) -> LazyBuffer: # 1s in locations where the max was chosen (can be two locations) max_is_1s = self.x.const(1.0).cast(dtypes.float).e(BinaryOps.ADD, self.x.e(BinaryOps.CMPNE, \ self.ret.expand(self.x.shape)).cast(dtypes.float).e(UnaryOps.NEG)) div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape) return max_is_1s.e(BinaryOps.MUL, div.e(UnaryOps.RECIP)).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape)) # ************* movement ops ************* # NOTE: this is sum in reverse class Expand(Function): def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer: self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so) return x.expand(shape) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(ReduceOps.SUM, self.expanded_axis).cast(grad_output.dtype) class Reshape(Function): def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer: self.input_shape = x.shape return x.reshape(shape) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.reshape(self.input_shape) class Permute(Function): def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer: self.input_order = order return x.permute(order) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.permute(argsort(self.input_order)) class Pad(Function): def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer: self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)]) return x.pad(arg) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.shrink(self.narg) class Shrink(Function): def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer: self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)]) return x.shrink(arg) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.pad(self.narg) class Flip(Function): def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer: self.arg = tuple([-1 if i in axis else 1 for i in range(len(x.shape))]) return x.stride(self.arg) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.stride(self.arg)