test_uops_stats.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import unittest
  2. from tinygrad import Tensor
  3. from tinygrad.engine.schedule import create_schedule
  4. from tinygrad.engine.realize import lower_schedule_item
  5. from tinygrad.codegen.uops import flops_mem, UOps, UOp
  6. from tinygrad.codegen.uopgraph import UOpGraph
  7. from tinygrad.ops import BinaryOps, TernaryOps
  8. from tinygrad.dtype import dtypes
  9. # TODO: can copy this in here when we remove it
  10. #from tinygrad.ops import get_lazyop_info
  11. #info = get_lazyop_info(ast)
  12. #print(ops, mem, expected_mem)
  13. #print(info.flops, info.mem_estimate)
  14. # **************** new FlopCounter ****************
  15. def get_stats(x:Tensor):
  16. si = create_schedule([x.lazydata])[-1]
  17. ei = lower_schedule_item(si)
  18. return ei.prg.op_estimate, ei.prg.mem_estimate
  19. class TestUOpsStats(unittest.TestCase):
  20. def test_simple_add(self):
  21. a = Tensor.empty(100,100)
  22. b = Tensor.empty(100,100)
  23. c = a+b
  24. ops, mem = get_stats(c)
  25. expected_ops = c.numel()
  26. expected_mem = a.nbytes() + b.nbytes() + c.nbytes()
  27. self.assertEqual(mem, expected_mem)
  28. # NOTE; ops also include indexing ops
  29. assert expected_ops <= ops and ops <= expected_ops * 2
  30. def test_simple_add_sq(self):
  31. a = Tensor.empty(100,100)
  32. b = Tensor.empty(100,100)
  33. c = (a+b)*(a+b)
  34. ops, mem = get_stats(c)
  35. expected_ops = c.numel()*2
  36. expected_mem = a.nbytes() + b.nbytes() + c.nbytes()
  37. self.assertEqual(mem, expected_mem)
  38. # NOTE; ops also include indexing ops
  39. assert expected_ops <= ops and ops <= expected_ops * 2
  40. def test_simple_matmul(self):
  41. a = Tensor.empty(1024,1024)
  42. b = Tensor.empty(1024,1024)
  43. c = a@b
  44. ops, mem = get_stats(c)
  45. expected_ops = c.numel() * 1024 * 2
  46. required_mem = a.nbytes() + b.nbytes() + c.nbytes()
  47. assert expected_ops <= ops and ops <= expected_ops * 1.2
  48. # NOTE: it's hard to assert on the memory here, all depends on caching
  49. assert required_mem <= mem
  50. #MULACC should have the same stats as MUL + ADD
  51. def test_mulacc(self):
  52. globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
  53. o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
  54. o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
  55. u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1))
  56. u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2))
  57. u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
  58. u4 = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
  59. u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
  60. uops = UOpGraph([u5])
  61. globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
  62. o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
  63. o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
  64. u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1))
  65. u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2))
  66. u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
  67. u4 = UOp(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
  68. uops_fma = UOpGraph([u4])
  69. self.assertEqual(flops_mem(uops.uops), flops_mem(uops_fma.uops))
  70. if __name__ == '__main__':
  71. unittest.main(verbosity=2)