- from __future__ import annotations
- from typing import Union, Tuple, Any, List, Dict, Callable
- import functools, hashlib, math, operator, ctypes, struct
- from enum import Enum, auto
- from dataclasses import dataclass
- from tinygrad.helpers import prod, dedup
- from tinygrad.dtype import dtypes, DType, ConstType
- from tinygrad.shape.symbolic import Variable, sint
- from tinygrad.shape.shapetracker import ShapeTracker
- # these are the llops your accelerator must implement, along with toCpu
- # the Enum class doesn't work with mypy, this is static. sorry it's ugly
- # NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
- # NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
- class UnaryOps(Enum):
- """A -> A (elementwise)"""
- EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto(); RECIP = auto() # noqa: E702
- class BinaryOps(Enum):
- """A + A -> A (elementwise)"""
- ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
- SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto() # noqa: E702
- class TernaryOps(Enum):
- """A + A + A -> A (elementwise)"""
- WHERE = auto(); MULACC = auto() # noqa: E702
- class ReduceOps(Enum):
- """A -> B (reduce)"""
- SUM = auto(); MAX = auto(); WMMA = auto() # noqa: E702
- class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
- class MetaOps(Enum):
- EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto(); KERNEL = auto() # noqa: E702
- Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps, BufferOps]
- # do not preserve f(0) = 0
- UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
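- # Example (illustrative, not part of the original file): padding fills with zeros, so only
- # ops with f(0) = 0 stay correct in the padded region. Using the python_alu reference
- # implementations defined later in this file:
- #   >>> python_alu[UnaryOps.EXP2](0), python_alu[UnaryOps.RECIP](0), python_alu[BinaryOps.MUL](0, 5)
- #   (1, inf, 0)
- # EXP2 and RECIP break the invariant and are listed above; MUL preserves it and is not.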
- @dataclass(frozen=True)
- class MemBuffer:
- idx: int
- dtype: DType
- st: ShapeTracker
- @dataclass(frozen=True)
- class ConstBuffer:
- val: ConstType | Variable
- dtype: DType
- st: ShapeTracker
- @dataclass(frozen=True)
- class KernelInfo:
- local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
- upcasted: int = 0 # count that are upcasted (this is remapping RANGE to EXPAND)
- @dataclass(frozen=True, eq=False)
- class LazyOp:
- op: Op
- src: Tuple[LazyOp, ...] = ()
- arg: Any = None
- def cached_compare(self, x, context):
- if id(self) == id(x): return True
- if self.op != x.op or self.arg != x.arg or len(self.src) != len(x.src): return False
- if (key := (id(self), id(x))) in context: return context[key]
- ret = context[key] = all(a.cached_compare(b, context) for a,b in zip(self.src, x.src))
- return ret
- def __eq__(self, x): return self.cached_compare(x, context={})
- def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
- @functools.cached_property
- def dtype(self) -> DType:
- if self.op in BufferOps: return self.arg.dtype
- if self.op is ReduceOps.WMMA: return self.arg[3] # WMMA can change the type
- if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
- return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPNE} else self.src[-1].dtype
- @functools.cached_property
- def full_shape(self):
- if len(self.src) == 0 and self.op in BufferOps: return self.arg.st.shape
- return tuple(max(x) for x in zip(*[x.full_shape for x in self.src]))
- @functools.cached_property
- def key(self) -> bytes:
- return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.arg)).encode())).digest()
- @functools.cached_property
- def hash(self): return hash((self.op, self.src, self.arg))
- def __hash__(self): return self.hash
- @functools.cached_property
- def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src for item in x.lazyops])
- def vars(self) -> List[Variable]:
- extract_vars = [x.arg.st.vars() for x in self.lazyops if x.op in BufferOps]
- const_vars = [x.arg.val for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
- return sorted(set.union(*extract_vars, set(const_vars)), key=lambda v: v.expr)
- # TODO: support non-lazyop
- def __add__(self, x:LazyOp): return LazyOp(BinaryOps.ADD, (self, x))
- def __sub__(self, x:LazyOp): return LazyOp(BinaryOps.ADD, (self, -x))
- def __mul__(self, x:LazyOp): return LazyOp(BinaryOps.MUL, (self, x))
- def ne(self, x:LazyOp): return LazyOp(BinaryOps.CMPNE, (self, x))
- def eq(self, x:LazyOp): return -self.ne(x)
- def __neg__(self): return LazyOp(UnaryOps.NEG, (self,))
- @staticmethod
- def const(val, dtype:DType, shape:Tuple[sint, ...]):
- return LazyOp(BufferOps.CONST, (), ConstBuffer(val, dtype, ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape)))
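- # Example (illustrative, assumed float32 scalars): the overloads above only build an AST,
- # they do not evaluate anything.
- #   >>> x = LazyOp.const(2.0, dtypes.float32, (1,))
- #   >>> y = LazyOp.const(3.0, dtypes.float32, (1,))
- #   >>> z = x * y + x        # LazyOp(BinaryOps.ADD, (LazyOp(BinaryOps.MUL, (x, y)), x))
- #   >>> z.dtype == dtypes.float32, z.full_shape, len(z.lazyops)
- #   (True, (1,), 4)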
- # **************** independent FlopCounter ****************
- @dataclass
- class FlopCounter:
- shape: Tuple[int, ...]
- flops: sint
- mem: Dict[int, int]
- @property
- def mem_estimate(self): return sum(self.mem.values())
- def consume_flops(self):
- self.flops, ret = 0, self.flops
- return ret
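- # Example (illustrative): consume_flops hands the running count to the parent and zeroes it
- # locally, so a node shared by two parents is not counted twice.
- #   >>> fc = FlopCounter((4,), 4, {})
- #   >>> fc.consume_flops(), fc.consume_flops()
- #   (4, 0)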
- InterpretedFlopCounter: Dict[Op, Callable] = {
- BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, 0, {arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
- BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, 0, {}),
- BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
- UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # cast uses no flops
- UnaryOps.BITCAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # bitcast uses no flops
- **{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op not in {UnaryOps.CAST, UnaryOps.BITCAST}}, # noqa: E501
- **{op:lambda self,y: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
- **{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
- TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
- @functools.lru_cache(None)
- def get_lazyop_info(ast:LazyOp) -> FlopCounter:
- @functools.lru_cache(None) # NOTE: this cache needs to be recreated for new ASTs
- def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
- return run_ast(ast)
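- # Example (illustrative sketch; buffer indices and dtypes are assumptions): cost of
- # out[i] = a[i] + b[i] over 16 float32 values.
- #   >>> st = ShapeTracker.from_shape((16,))
- #   >>> a = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float32, st))
- #   >>> b = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float32, st))
- #   >>> info = get_lazyop_info(LazyOp(BufferOps.STORE, (a + b,), MemBuffer(0, dtypes.float32, st)))
- #   >>> info.flops, info.mem_estimate
- #   (16, 192)      # 16 adds; 3 buffers * 16 elements * 4 bytes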
- # **************** ops in python ****************
- def hook_overflow(dv, fxn):
- def wfxn(*args):
- try: return fxn(*args)
- except OverflowError: return dv
- return wfxn
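- # Example (illustrative): used below so that float overflow saturates instead of raising.
- #   >>> hook_overflow(math.inf, lambda x: 2**x)(2000.0)
- #   inf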
- python_alu: Dict[Op, Callable] = {
- UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
- UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
- UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan,
- UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
- UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
- UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
- BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
- BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
- BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
- BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_,
- BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x, y: int(x/y) if y != 0 else x*math.inf,
- TernaryOps.MULACC: lambda x,y,z: (x*y)+z,
- TernaryOps.WHERE: lambda x,y,z: y if x else z}
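- # Example (illustrative): MOD follows C-style truncated remainder (sign of the dividend),
- # which differs from Python's % operator.
- #   >>> python_alu[BinaryOps.MOD](-7, 3), -7 % 3
- #   (-1, 2)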
- def truncate_fp16(x):
- try:
- x = float(x)
- struct.pack("@e", x)
- return x
- except OverflowError: return math.copysign(math.inf, x)
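- # Example (illustrative): out-of-range values saturate to +/-inf; in-range values pass
- # through unchanged (this guards overflow only, it does not round to half precision).
- #   >>> truncate_fp16(100000.0), truncate_fp16(100.0)
- #   (inf, 100.0)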
- truncate: Dict[DType, Callable] = {dtypes.bool: bool,
- # TODO: bfloat16
- dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
- dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
- dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
- dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value \
- if isinstance(x,int) else x, dtypes.int64: lambda x: ctypes.c_int64(x).value, dtypes.bigint: lambda x: x }
- def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
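- # Example (illustrative): the truncate table makes Python arithmetic wrap the way the
- # target dtype would on device.
- #   >>> exec_alu(BinaryOps.ADD, dtypes.uint8, (200, 100))
- #   44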
- def reduce_st(st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(st.shape))
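- # Example (illustrative): reducing axes (1, 2) of shape (3, 4, 5) keeps the rank and turns
- # the reduced axes into 1s.
- #   >>> reduce_st(ShapeTracker.from_shape((3, 4, 5)), (1, 2))
- #   (3, 1, 1)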
- # the living definition of LazyOps
- def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
- assert ast.op is MetaOps.KERNEL, "kernel AST root must be MetaOps.KERNEL"
- sts: Dict[LazyOp, ShapeTracker] = {}
- def dfs(op:LazyOp, st:ShapeTracker):
- if op in sts: return
- # restore globals from the two-stage reduce
- if op.op is BufferOps.LOAD and op.arg.idx == -1:
- dfs(local_reduce:=op.src[0].src[0], op.arg.st)
- return sts.setdefault(op, sts[local_reduce])
- for x in op.src: dfs(x, st)
- # only reduceop is allowed to change shape, limited to turning n to 1
- if op.op in ReduceOps:
- axis = op.arg[-1] if op.op is ReduceOps.WMMA else op.arg
- assert isinstance(axis, tuple) and all(isinstance(i, int) for i in axis), f"reduceop must have axis {op.arg}"
- st = ShapeTracker.from_shape(reduce_st(sts[op.src[0]], axis))
- else:
- # movementops are pushed to the edges with LOAD
- if op.op in BufferOps: st = op.arg.st
- else: st = sts[op.src[0]]
- for x in op.src: assert sts[x].shape == st.shape, f"found implicit movement op {x.op} {sts[x].shape} != {op.op} {st.shape}"
- sts[op] = st
- for i, out in enumerate(ast.src):
- assert out.arg.idx == i, f"unexpected output buffer idx {out.arg.idx} != {i}"
- assert out.op is BufferOps.STORE, f"kernels must have stores as the output, got {out.op}"
- assert out.arg.st.size == ast.src[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
- dfs(out, out.arg.st)
- return sts
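- # Example (illustrative sketch; buffer indices and dtypes are assumptions): a minimal valid
- # kernel AST that sums 16 floats from buffer 1 into buffer 0.
- #   >>> ld = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((16,))))
- #   >>> red = LazyOp(ReduceOps.SUM, (ld,), (0,))
- #   >>> store = LazyOp(BufferOps.STORE, (red,), MemBuffer(0, dtypes.float32, ShapeTracker.from_shape((1,))))
- #   >>> sts = verify_lazyop(LazyOp(MetaOps.KERNEL, (store,)))
- #   >>> sts[red].shape
- #   (1,)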