# helpers.py
  1. import sys, unittest
  2. import numpy as np
  3. from tinygrad import Tensor, Device, dtypes
  4. from tinygrad.codegen.uops import UOp
  5. from tinygrad.tensor import _to_np_dtype
  6. from tinygrad.engine.realize import Runner
  7. from tinygrad.dtype import DType
  8. from tinygrad.nn.state import get_parameters
  9. from tinygrad.helpers import Context, CI, OSX, getenv
  10. def derandomize_model(model):
  11. with Context(GRAPH=0):
  12. for p in get_parameters(model):
  13. p.lazydata = Tensor.empty(p.shape, device=p.device, dtype=p.dtype).lazydata
  14. p.realize()
  15. def assert_jit_cache_len(fxn, expected_len):
  16. if not fxn.jit_cache:
  17. assert expected_len == 0, expected_len
  18. return
  19. # until we have a better way of typing the prg in ExecItem
  20. if issubclass(type(fxn.jit_cache[0].prg), Runner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'):
  21. assert len(fxn.jit_cache) == expected_len, len(fxn.jit_cache)
  22. else:
  23. assert len(fxn.jit_cache) == 1, len(fxn.jit_cache)
  24. # until we have a better way of typing the prg in ExecItem
  25. assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
  26. assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
  27. def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
  28. if dtype == dtypes.bigint and device != "PYTHON": return False
  29. if dtype == dtypes.bfloat16:
  30. # NOTE: this requires bf16 buffer support
  31. return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
  32. if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
  33. # for CI GPU and OSX, cl_khr_fp16 isn't supported
  34. # for CI LLVM, it segfaults because it can't link to the casting function
  35. # CUDACPU architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
  36. # PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751
  37. if dtype == dtypes.half:
  38. if device == "GPU": return not CI and not OSX
  39. if device in ["LLVM", "CUDA", "NV"]: return not CI
  40. if device == "PYTHON": return sys.version_info >= (3, 12)
  41. if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
  42. return True
  43. def rand_for_dtype(dt:DType, size:int):
  44. if dtypes.is_unsigned(dt):
  45. return np.random.randint(0, 100, size=size, dtype=_to_np_dtype(dt))
  46. elif dtypes.is_int(dt):
  47. return np.random.randint(-100, 100, size=size, dtype=_to_np_dtype(dt))
  48. elif dt == dtypes.bool:
  49. return np.random.choice([True, False], size=size)
  50. return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt))
  51. class TestUOps(unittest.TestCase):
  52. def assert_equiv_uops(self, uop1:UOp, uop2:UOp):
  53. # NOTE: direct UOps __eq__ is comparing object reference, use this function to compare two uops
  54. self.assertIs(uop1.op, uop2.op)
  55. self.assertEqual(uop1.dtype, uop2.dtype)
  56. self.assertEqual(uop1.arg, uop2.arg)
  57. self.assertEqual(len(uop1.src), len(uop2.src))
  58. for s1, s2 in zip(uop1.src, uop2.src): self.assert_equiv_uops(s1, s2)