# fuzz_matmul.py

  1. import numpy as np
  2. from tinygrad.helpers import getenv
  3. from tinygrad import dtypes, Tensor
  4. dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
  5. acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
  6. N_START = getenv("N_START", 1)
  7. M_START = getenv("M_START", 1)
  8. K_START = getenv("K_START", 1)
  9. N_STOP = getenv("N_STOP", 32)
  10. M_STOP = getenv("M_STOP", N_STOP)
  11. K_STOP = getenv("K_STOP", N_STOP)
  12. N_STEP = getenv("N_STEP", 1)
  13. M_STEP = getenv("M_STEP", 1)
  14. K_STEP = getenv("K_STEP", 1)
  15. ATOL = getenv("ATOL", 1e-4)
  16. RTOL = getenv("RTOL", 3e-2)
  17. if __name__ == "__main__":
  18. failed = []
  19. for M in range(M_START, M_STOP+1, M_STEP):
  20. for N in range(N_START, N_STOP+1, N_STEP):
  21. for K in range(K_START, K_STOP+1, K_STEP):
  22. print(f"testing {M=} {N=} {K=}")
  23. a, b = Tensor.rand(M, K, dtype=dtype_in).realize(), Tensor.rand(K, N, dtype=dtype_in).realize()
  24. c = a.matmul(b, acc_dtype=acc_dtype).realize()
  25. comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
  26. nc = c.numpy()
  27. try:
  28. np.testing.assert_allclose(nc, comp, atol=ATOL, rtol=RTOL)
  29. except AssertionError as e:
  30. failed.append((M,N,K,))
  31. if getenv("DEBUG_VALUES") > 0:
  32. indices = np.where(~np.isclose(nc, comp, rtol=RTOL, atol=ATOL))
  33. non_matching_elements_nc = nc[indices]
  34. non_matching_elements_comp = comp[indices]
  35. print(indices)
  36. print("result :", non_matching_elements_nc)
  37. print("ground truth:", non_matching_elements_comp)
  38. print(e)
  39. pass
  40. print(f"failed sizes: {failed}")
  41. print(f"num failures: {len(failed)}")
  42. if len(failed) > 0:
  43. raise RuntimeError(f"failed on {len(failed)} kernels")