# external_benchmark_multitensor_allreduce.py
import time
from tinygrad import Tensor, Device, GlobalCounters, TinyJit
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import ReduceOps
from tinygrad.multi import MultiLazyBuffer, all_reduce
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from tinygrad.helpers import getenv, Context, RING
from typing import List, Union
  10. def realize(x: Union[LazyBuffer, List[LazyBuffer]]):
  11. x = x if isinstance(x, list) else [x]
  12. run_schedule(create_schedule(x))
  13. for lb in x: Device[lb.device].synchronize()
  14. def test(devs: List[str], N: int, iters:int = 10):
  15. def _wrapped(op: ReduceOps, t: Tensor) -> Tensor:
  16. return Tensor(MultiLazyBuffer(all_reduce(op, t.lazydata.lbs), 0), device=devs)
  17. _jitted = TinyJit(_wrapped) if getenv("USEJIT", 1) == 1 else _wrapped
  18. secs, gflops, gbs = 0, 0, 0
  19. for i in range(-2, iters):
  20. GlobalCounters.reset()
  21. lbs = [Tensor.full((N,), float(1+i), device=d).contiguous().lazydata for i,d in enumerate(devs)]
  22. realize(lbs)
  23. start = time.time()
  24. realize(_jitted(ReduceOps.SUM, Tensor(MultiLazyBuffer(lbs, 0), device=devs)).lazydata.lbs)
  25. end = time.time()
  26. if i < 0:
  27. # First time is slow due to kernel compilation
  28. continue
  29. i_secs = end-start
  30. i_gflops = GlobalCounters.global_ops/i_secs/10**9
  31. i_gbs = (N*4)/i_secs/10**9
  32. print(f"{'ring_allreduce' if RING >= 2 else 'naive_allreduce'} iter {i+1}/{iters}: {i_secs:.6f} sec {i_gflops:.2f} GFLOP/s {i_gbs:.2f} GB/s")
  33. secs += i_secs
  34. gflops += i_gflops
  35. gbs += i_gbs
  36. return (gflops/iters, gbs/iters, secs/iters)
  37. def main():
  38. dev, n_gpus = Device.DEFAULT, getenv("GPUS", 6) # number of gpus
  39. devs = tuple([f"{dev}:{x}" for x in range(n_gpus)])
  40. sz = getenv("SZ", 1000) * 10**6 # size of data on each gpu
  41. f32 = 4 # 4 bytes
  42. N = sz//f32
  43. print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.")
  44. with Context(RING=2):
  45. (ring_gflops, ring_gbs, ring_secs) = test(devs, N)
  46. with Context(RING=0):
  47. (naive_gflops, naive_gbs, naive_secs) = test(devs, N)
  48. print(f"Ring:\n {ring_secs:.6f} seconds/iter\n {ring_gflops:.2f} GFLOP/s\n {ring_gbs:.2f} GB/s")
  49. print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")
  50. if __name__ == "__main__":
  51. main()