test_flopcounter.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. #!/usr/bin/env python
  2. import unittest
  3. from tinygrad import dtypes, Tensor
  4. from tinygrad.helpers import prod
  5. from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, get_lazyop_info, BufferOps, MemBuffer
  6. from tinygrad.shape.shapetracker import ShapeTracker
  7. from tinygrad.codegen.kernel import Kernel
  8. from tinygrad.codegen.uops import flops_mem
  9. class TestFlopCounter(unittest.TestCase):
  10. def setUp(self):
  11. self.buf0 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,))))
  12. self.buf1 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,))))
  13. self.buf2 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,4))))
  14. def compare_flop_counters(self, ast):
  15. info = get_lazyop_info(ast.src[0])
  16. lin = Kernel(ast)
  17. # NOTE: why does hand coded optimizations change flops for the GEMM?
  18. #lin.hand_coded_optimizations()
  19. lin.linearize()
  20. ops, mem = flops_mem(lin.uops.uops, ignore_indexing=True)
  21. run_count = prod((lin.global_size or []) + (lin.local_size or []))
  22. self.assertEqual(info.flops, ops*run_count)
  23. print(info.flops, info.mem_estimate, "vs", ops*run_count, mem*run_count)
  24. #lin.uops.print()
  25. def test_flops_sin(self):
  26. op0 = LazyOp(UnaryOps.SIN, (self.buf0,), None)
  27. info = get_lazyop_info(op0)
  28. self.assertEqual(info.flops, 4)
  29. def test_flops_add(self):
  30. op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
  31. info = get_lazyop_info(op0)
  32. self.assertEqual(info.flops, 4)
  33. def test_flops_add_twice(self):
  34. op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
  35. op1 = LazyOp(BinaryOps.ADD, (op0,self.buf1,), None)
  36. info = get_lazyop_info(op1)
  37. self.assertEqual(info.flops, 8)
  38. def test_flops_add_self(self):
  39. op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
  40. op1 = LazyOp(BinaryOps.ADD, (op0,op0,), None)
  41. info = get_lazyop_info(op1)
  42. self.assertEqual(info.flops, 8)
  43. def test_flops_add_roundabout_self(self):
  44. op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
  45. op1 = LazyOp(BinaryOps.ADD, (op0,self.buf1,), None)
  46. op2 = LazyOp(BinaryOps.ADD, (op0,op1,), None)
  47. info = get_lazyop_info(op2)
  48. self.assertEqual(info.flops, 12)
  49. def test_flops_red(self):
  50. op0 = LazyOp(BinaryOps.MUL, (self.buf0,self.buf1,), None)
  51. op1 = LazyOp(ReduceOps.SUM, (op0,), (0,))
  52. op2 = LazyOp(BinaryOps.ADD, (op1, op1,), None)
  53. info = get_lazyop_info(op2)
  54. self.assertEqual(info.flops, 9)
  55. def test_flops_sum1d(self):
  56. op0 = LazyOp(ReduceOps.SUM, (self.buf0,), (0,))
  57. info = get_lazyop_info(op0)
  58. self.assertEqual(info.flops, 4)
  59. self.assertEqual(info.shape, (1,))
  60. def test_flops_sum2d(self):
  61. op0 = LazyOp(ReduceOps.SUM, (self.buf2,), (0,))
  62. info = get_lazyop_info(op0)
  63. self.assertEqual(info.flops, 16)
  64. self.assertEqual(info.shape, (1,4))
  65. op1 = LazyOp(ReduceOps.SUM, (op0,), (1,))
  66. info = get_lazyop_info(op1)
  67. self.assertEqual(info.flops, 16+4)
  68. self.assertEqual(info.shape, (1,1))
  69. def test_flops_conv(self):
  70. out = Tensor.empty(16,3,16,16).conv2d(Tensor.empty(64,3,3,3))
  71. self.compare_flop_counters(out.schedule()[-1].ast)
  72. def test_flops_gemm(self):
  73. out = Tensor.empty(4,16,16) @ Tensor.empty(4,16,16)
  74. self.compare_flop_counters(out.schedule()[-1].ast)
  75. if __name__ == '__main__':
  76. unittest.main()