dtype.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union
  2. from dataclasses import dataclass
  3. import functools
  4. from tinygrad.helpers import getenv
# Python scalar types a DType constant can hold; dtypes.as_const converts a value to int, float or bool by dtype.
ConstType = Union[float, int, bool]
  6. @dataclass(frozen=True, order=True)
  7. class DType:
  8. priority: int # this determines when things get upcasted
  9. itemsize: int
  10. name: str
  11. fmt: Optional[str]
  12. count: int
  13. def __repr__(self): return f"dtypes.{'_'*(c:=self.count!=1)}{INVERSE_DTYPES_DICT[self.name if not c else self.scalar().name]}{str(self.count)*c}"
  14. def vec(self, sz:int):
  15. assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}"
  16. return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
  17. def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
  18. # dependent typing?
  19. @dataclass(frozen=True, repr=False)
  20. class ImageDType(DType):
  21. shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
  22. base: DType
  23. def scalar(self): return self.base
  24. def vec(self, sz:int): return self.base.vec(sz)
  25. def __repr__(self): return f"dtypes.{self.name}({self.shape})"
  26. # @dataclass(frozen=True, init=False, repr=False, eq=False)
  27. class PtrDType(DType):
  28. def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
  29. def __repr__(self): return f"ptr.{super().__repr__()}"
  30. def __hash__(self): return super().__hash__()
  31. def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
  32. def __ne__(self, dt): return not (self == dt)
  33. class dtypes:
  34. @staticmethod
  35. def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64)
  36. @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
  37. def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.bigint) or dtypes.is_unsigned(x)
  38. @staticmethod
  39. def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
  40. @staticmethod
  41. def from_py(x) -> DType:
  42. if x.__class__ is float: return dtypes.default_float
  43. if x.__class__ is int: return dtypes.default_int
  44. if x.__class__ is bool: return dtypes.bool
  45. # put this in the last is faster because there are more items than lists/tuples to check
  46. if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
  47. raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
  48. @staticmethod
  49. def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
  50. @staticmethod
  51. def min(dtype:DType):
  52. if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
  53. return -float("inf") if dtypes.is_float(dtype) else False
  54. @staticmethod
  55. def max(dtype:DType):
  56. if dtypes.is_int(dtype): return (2**(dtype.itemsize*8-(0 if dtypes.is_unsigned(dtype) else 1)))-1
  57. return float("inf") if dtypes.is_float(dtype) else True
  58. @staticmethod
  59. def fields() -> Dict[str, DType]: return DTYPES_DICT
  60. bigint: Final[DType] = DType(-1, 0, "bigint", None, 1) # arbitrary precision integer
  61. bool: Final[DType] = DType(0, 1, "bool", '?', 1)
  62. int8: Final[DType] = DType(1, 1, "char", 'b', 1)
  63. uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)
  64. int16: Final[DType] = DType(3, 2, "short", 'h', 1)
  65. uint16: Final[DType] = DType(4, 2, "unsigned short", 'H', 1)
  66. int32: Final[DType] = DType(5, 4, "int", 'i', 1)
  67. uint32: Final[DType] = DType(6, 4, "unsigned int", 'I', 1)
  68. int64: Final[DType] = DType(7, 8, "long", 'l', 1)
  69. uint64: Final[DType] = DType(8, 8, "unsigned long", 'L', 1)
  70. float16: Final[DType] = DType(9, 2, "half", 'e', 1)
  71. # bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
  72. bfloat16: Final[DType] = DType(10, 2, "__bf16", None, 1)
  73. float32: Final[DType] = DType(11, 4, "float", 'f', 1)
  74. float64: Final[DType] = DType(12, 8, "double", 'd', 1)
  75. # dtype aliases
  76. half = float16; float = float32; double = float64 # noqa: E702
  77. uchar = uint8; ushort = uint16; uint = uint32; ulong = uint64 # noqa: E702
  78. char = int8; short = int16; int = int32; long = int64 # noqa: E702
  79. # NOTE: these are image dtypes
  80. @staticmethod
  81. def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, shape=shp, base=dtypes.float32)
  82. @staticmethod
  83. def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, shape=shp, base=dtypes.float32)
  84. default_float: ClassVar[DType] = float32
  85. default_int: ClassVar[DType] = int32
  86. if (env_default_float := getenv("DEFAULT_FLOAT", "")):
  87. dtypes.default_float = getattr(dtypes, env_default_float.lower())
  88. assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
  89. # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
  90. # we don't support weak type and complex type
  91. promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
  92. dtypes.int64: [dtypes.float16, dtypes.bfloat16], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
  93. dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16],
  94. dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
  95. @functools.lru_cache(None)
  96. def _get_recursive_parents(dtype:DType) -> Set[DType]:
  97. return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
  98. @functools.lru_cache(None)
  99. def least_upper_dtype(*ds:DType) -> DType:
  100. return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
  101. def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
  102. # HACK: staticmethods are not callable in 3.8 so we have to compare the class
  103. DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'bigint')) or v.__class__ is staticmethod)}
  104. INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
  105. INVERSE_DTYPES_DICT['bigint'] = 'bigint'
  106. def sum_acc_dtype(dt:DType):
  107. # default acc dtype for sum
  108. if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
  109. if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
  110. return least_upper_dtype(dt, dtypes.float)