# tinygrad renderer __init__.py (extraction header removed)
  1. from typing import Optional, List, Tuple, Dict
  2. import functools
  3. from dataclasses import dataclass
  4. from tinygrad.helpers import to_function_name
  5. from tinygrad.codegen.uopgraph import UOpGraph
  6. from tinygrad.shape.symbolic import sym_infer, sint, Variable
  7. from tinygrad.dtype import DType
  8. @dataclass(frozen=True)
  9. class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
  10. dims: Tuple[int,int,int] # N, M, K
  11. dtype_in: DType # dtype for A and B
  12. dtype_out: DType # dtype for C and D
  13. threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
  14. thread_local_sizes: List[List[int]] # in each thread, the number of elements stored in registers for each TC dim
  15. def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
  16. @dataclass(frozen=True)
  17. class Program:
  18. name:str
  19. src:str
  20. dname:str
  21. global_size:Optional[List[int]]=None
  22. local_size:Optional[List[int]]=None
  23. uops:Optional[UOpGraph]=None
  24. op_estimate:sint=0
  25. mem_estimate:sint=0
  26. @functools.cached_property
  27. def vars(self) -> List[Variable]: return [] if self.uops is None else self.uops.vars()
  28. @functools.cached_property
  29. def globals(self) -> List[Tuple[int, bool]]: return [] if self.uops is None else self.uops.globals()
  30. @functools.cached_property
  31. def outcount(self) -> int: return sum(x[1] for x in self.globals)
  32. @functools.cached_property
  33. def function_name(self) -> str: return to_function_name(self.name)
  34. def launch_dims(self, var_vals:Dict[Variable, int]):
  35. global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
  36. local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
  37. return global_size, local_size
class Renderer:
  # Base class for backend code renderers. Subclasses describe their device's
  # capabilities via these class attributes and implement render() to turn a
  # UOpGraph into device source code.
  device: str = ""  # device identifier string for this renderer's backend
  suffix: str = ""  # presumably a device-name suffix; verify against subclasses — TODO confirm
  # TODO: make this generic with a list of supported types
  supports_float4: bool = True  # whether the backend supports 4-wide vectorized loads/stores
  has_local: bool = True   # backend exposes local (workgroup) dimensions
  has_shared: bool = True  # backend has shared (local/LDS) memory
  # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
  # NOTE(review): 0x8FFFFFFF (2415919103) is larger than the signed int32 max
  # (0x7FFFFFFF) despite the int32-index TODO below — confirm this is intentional.
  global_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
  local_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
  shared_max: int = 32768  # max shared memory per workgroup, in bytes
  # NOTE(review): shared mutable class attribute — subclasses should override
  # with their own list, never mutate this default in place.
  tensor_cores: List[TensorCore] = []
  # Render the kernel named `name` from `uops` into backend source; must be overridden.
  def render(self, name:str, uops:UOpGraph) -> str: raise NotImplementedError("needs a renderer")