# tvm_gemm.py
# Prints the C source that TVM and tinygrad each generate for the same
# (M, K) x (K, N) matrix multiplication, for side-by-side comparison.
  # Reference: https://tvm.apache.org/docs/tutorial/tensor_expr_get_started.html#example-2-manually-optimizing-matrix-multiplication-with-te

# Matrix dimensions shared by the TVM and tinygrad sections:
# C[M, N] = A[M, K] @ B[K, N].
M, N, K = 1024, 1024, 1024

# Keep the ``try`` limited to the imports so that only a missing TVM install
# produces the "please install" message; errors raised by TVM itself while
# building still surface instead of being misreported.
try:
    import tvm
    from tvm import te
except ImportError:
    print("** please install TVM for TVM output")
else:
    # Code-generation target. "c" emits C source; "opencl" is another option.
    # To list the target kinds this TVM build supports:
    #   print(tvm.target.Target.list_kinds())
    target = tvm.target.Target(target="c")

    # Matrix multiplication in TE: for each (x, y), reduce over k.
    k = te.reduce_axis((0, K), "k")
    A = te.placeholder((M, K), name="A")
    B = te.placeholder((K, N), name="B")
    C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")

    # Default schedule: no manual scheduling transforms are applied.
    s = te.create_schedule(C.op)
    # To inspect the lowered IR before code generation:
    #   print(tvm.lower(s, [A, B, C], simple_mode=True))

    # Build the kernel as "mmult" and print the generated C source.
    func = tvm.build(s, [A, B, C], target=target, name="mmult")
    print(func.get_source())
  # tinygrad version of the same matmul (uses M, N, K defined above).
import os
from tinygrad.tensor import Tensor
from tinygrad.engine.schedule import create_schedule
from tinygrad.codegen.kernel import Kernel
from tinygrad.device import CompilerOptions
from tinygrad.runtime.ops_clang import renderer

# Define the compute: broadcast A as (M, 1, K) against B^T as (1, N, K),
# multiply elementwise, then sum over the last axis.
# This gives C[i, j] = sum_k A[i, k] * B[k, j].
A = Tensor.rand(M, K, device="clang")
B = Tensor.rand(K, N, device="clang")
C = (A.reshape(M, 1, K) * B.permute(1, 0).reshape(1, N, K)).sum(axis=2)

# Turn the lazy graph into a list of schedule items (kernels).
# NOTE(review): the code below takes the LAST item to be the matmul/reduce
# kernel (the rand kernels presumably come first). Confirm this for your
# tinygrad version.
sched = create_schedule([C.lazydata])

# Build the kernel with has_local and supports_float4 disabled.
# These presumably match a plain-C target; confirm against CompilerOptions.
lin = Kernel(sched[-1].ast, CompilerOptions(has_local=False, supports_float4=False))
# To apply tinygrad's hand-coded optimizations before linearizing:
#   lin.hand_coded_optimizations()
lin.linearize()

# Render the linearized uops with the clang backend's renderer, as "mmult".
src = renderer("mmult", lin.uops)
print(src)