from typing import Optional, List, Tuple, Dict import functools from dataclasses import dataclass from tinygrad.helpers import to_function_name from tinygrad.codegen.uopgraph import UOpGraph from tinygrad.shape.symbolic import sym_infer, sint, Variable from tinygrad.dtype import DType @dataclass(frozen=True) class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N) dims: Tuple[int,int,int] # N, M, K dtype_in: DType # dtype for A and B dtype_out: DType # dtype for C and D threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure thread_local_sizes: List[List[int]] # in each thread, the number of elements stored in registers for each TC dim def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name]) @dataclass(frozen=True) class Program: name:str src:str dname:str global_size:Optional[List[int]]=None local_size:Optional[List[int]]=None uops:Optional[UOpGraph]=None op_estimate:sint=0 mem_estimate:sint=0 @functools.cached_property def vars(self) -> List[Variable]: return [] if self.uops is None else self.uops.vars() @functools.cached_property def globals(self) -> List[Tuple[int, bool]]: return [] if self.uops is None else self.uops.globals() @functools.cached_property def outcount(self) -> int: return sum(x[1] for x in self.globals) @functools.cached_property def function_name(self) -> str: return to_function_name(self.name) def launch_dims(self, var_vals:Dict[Variable, int]): global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None return global_size, local_size class Renderer: device: str = "" suffix: str = "" # TODO: make this generic with a list of supported types supports_float4: bool = True has_local: bool = True has_shared: bool = True # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims global_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now local_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now shared_max: int = 32768 tensor_cores: List[TensorCore] = [] def render(self, name:str, uops:UOpGraph) -> str: raise NotImplementedError("needs a renderer")