# simple_matmul.py
  1. import numpy as np
  2. from tinygrad.helpers import getenv
  3. from tinygrad import dtypes, Tensor
  4. dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
  5. acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
  6. N = getenv("N", 4096)
  7. M = getenv("M", N)
  8. K = getenv("K", N)
  9. CNT = getenv("CNT", 10)
  10. ATOL = getenv("ATOL", 1e-4)
  11. RTOL = getenv("RTOL", 3e-2)
  12. if __name__ == "__main__":
  13. a, b = Tensor.rand(M, K, dtype=dtype_in).realize(), Tensor.rand(K, N, dtype=dtype_in).realize()
  14. for i in range(CNT):
  15. if i > 0 and getenv("RAND", 0) != 0:
  16. a, b = Tensor.rand(M, K, dtype=dtype_in).realize(), Tensor.rand(K, N, dtype=dtype_in).realize()
  17. c = a.matmul(b, acc_dtype=acc_dtype).realize()
  18. comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
  19. nc = c.numpy()
  20. try:
  21. np.testing.assert_allclose(nc, comp, atol=ATOL, rtol=RTOL)
  22. except AssertionError as e:
  23. if getenv("DEBUG_VALUES") > 0:
  24. indices = np.where(~np.isclose(nc, comp, rtol=RTOL, atol=ATOL))
  25. non_matching_elements_nc = nc[indices]
  26. non_matching_elements_comp = comp[indices]
  27. print(indices)
  28. print("result :", non_matching_elements_nc)
  29. print("ground truth:", non_matching_elements_comp)
  30. raise e