simple_matvec.py 1.1 KB

123456789101112131415161718192021222324252627282930
  1. import numpy as np
  2. from tinygrad.helpers import getenv
  3. from tinygrad import dtypes, Tensor, Device
  4. dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
  5. acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
  6. GPUS = getenv("GPUS", 0)
  7. M = getenv("M", 16384)
  8. N = getenv("N", 4096)
  9. CNT = getenv("CNT", 10)
  10. ATOL = getenv("ATOL", 1e-4)
  11. RTOL = getenv("RTOL", 3e-2)
  12. def _rand(device):
  13. a, b = Tensor.rand(M, N, dtype=dtype_in).realize(), Tensor.rand(N, dtype=dtype_in).realize()
  14. if isinstance(device, tuple):
  15. a.shard_(device, axis=1)
  16. b.shard_(device, axis=0)
  17. return a, b
  18. if __name__ == "__main__":
  19. device = tuple(f"{Device.DEFAULT}:{i}" for i in range(GPUS)) if GPUS > 1 else Device.DEFAULT
  20. a, b = _rand(device)
  21. for i in range(CNT):
  22. if i > 0 and getenv("RAND", 0) != 0:
  23. a, b = _rand(device)
  24. c = a.matmul(b, acc_dtype=acc_dtype).realize()
  25. nc = c.numpy()
  26. comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
  27. np.testing.assert_allclose(nc, comp, atol=ATOL, rtol=RTOL)