# test_custom_function.py
# this is an example of how you can write terrible DSP compute breaking ops like warpPerspective
# here we use a CUSTOM op to write atan2
import unittest
import numpy as np
from typing import Optional, Tuple
from tinygrad.helpers import prod
from tinygrad.dtype import dtypes
# *** first, we implement the atan2 op at the lowest level ***
# `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers
from tinygrad.lazy import Buffer, create_lazybuffer
from tinygrad.device import Device
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.realize import CompiledRunner
from tinygrad.renderer import Program
  15. # we don't always have GPU support, so the type signature is the abstract CompiledBuffer instead of GPUBuffer
  16. def atan2_gpu(ret:Buffer, a:Buffer, b:Buffer):
  17. assert a.dtype == b.dtype and a.dtype == dtypes.float32, "gpu function only supports float32"
  18. src = """
  19. __kernel void atan2_gpu(global float *c, global float *a, global float *b) {
  20. int idx = get_global_id(0);
  21. c[idx] = atan2(a[idx], b[idx]);
  22. }"""
  23. CompiledRunner(Program("atan2_gpu", src, ret.device, global_size=[ret.size,1,1])).exec([ret, a, b])
  24. def atan2_cpu(ret:Buffer, a:Buffer, b:Buffer): ret.copyin(np.require(np.arctan2(a._buf, b._buf), requirements='C').data)
# *** second, we write the ATan2 mlop ***
# NOTE: The derivative of atan2 doesn't need a custom op! https://www.liquisearch.com/atan2/derivative
# In general, it is also optional to write a backward function, just your backward pass won't work without it
from tinygrad.ops import MetaOps, BinaryOps, UnaryOps
from tinygrad.lazy import LazyBuffer
from tinygrad.tensor import Function
  31. class ATan2(Function):
  32. def forward(self, a:LazyBuffer, b:LazyBuffer) -> LazyBuffer:
  33. assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch"
  34. self.a, self.b = a, b
  35. return create_lazybuffer(a.device, ShapeTracker.from_shape(a.shape), max(a.dtype, b.dtype), MetaOps.CUSTOM,
  36. arg={"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device], srcs=(a.contiguous(), b.contiguous()))
  37. def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
  38. recip = (self.a.e(BinaryOps.MUL, self.a)).e(BinaryOps.ADD, self.b.e(BinaryOps.MUL, self.b)).e(UnaryOps.RECIP)
  39. return grad_output.e(BinaryOps.MUL, self.b.e(BinaryOps.MUL, recip)) if self.needs_input_grad[0] else None, \
  40. grad_output.e(BinaryOps.MUL, self.a.const(0).e(BinaryOps.ADD, self.a.e(UnaryOps.NEG)).e(BinaryOps.MUL, recip)) \
  41. if self.needs_input_grad[1] else None
# *** third, we use our lovely new mlop in some tests ***
from tinygrad.tensor import Tensor
@unittest.skipUnless(Device.DEFAULT in ["CPU", "GPU"], "atan2 is only implemented for CPU and GPU")
class TestCustomFunction(unittest.TestCase):
  def test_atan2_forward(self):
    """Forward pass of the custom op matches np.arctan2."""
    # create some random Tensors, permute them just because we can
    a = Tensor.randn(4,4,requires_grad=True).permute(1,0)
    b = Tensor.randn(4,4,requires_grad=True).permute(1,0)

    # run the forward pass. note: up until the .numpy(), it's all lazy
    c = ATan2.apply(a, b)
    print(c.numpy())

    # check the forward pass (in numpy)
    np.testing.assert_allclose(c.numpy(), np.arctan2(a.numpy(), b.numpy()), atol=1e-5)

  # fun fact, this never actually calls forward, so it works in all the backends
  def test_atan2_backward(self):
    """Backward-pass gradients agree with torch.atan2 autograd."""
    # have to go forward before we can go backward
    a = Tensor.randn(4,4,requires_grad=True).permute(1,0)
    b = Tensor.randn(4,4,requires_grad=True).permute(1,0)
    c = ATan2.apply(a, b)

    # run the backward pass
    c.mean().backward()
    assert a.grad is not None and b.grad is not None, "tinygrad didn't compute gradients"
    print(a.grad.numpy())
    print(b.grad.numpy())

    # check the backward pass (in torch)
    import torch
    ta, tb = torch.tensor(a.numpy(), requires_grad=True), torch.tensor(b.numpy(), requires_grad=True)
    tc = torch.atan2(ta, tb)
    tc.mean().backward()
    assert ta.grad is not None and tb.grad is not None, "torch didn't compute gradients"
    np.testing.assert_allclose(a.grad.numpy(), ta.grad.numpy(), atol=1e-5)
    np.testing.assert_allclose(b.grad.numpy(), tb.grad.numpy(), atol=1e-5)

  @unittest.skipIf(Device.DEFAULT in ["CPU"], "atan2_cpu not jittable")
  def test_atan2_jit(self):
    """The custom op also runs correctly when captured by the TinyJit."""
    # custom ops even work in the JIT!
    from tinygrad.engine.jit import TinyJit

    @TinyJit
    def jitted_atan2(a:Tensor, b:Tensor) -> Tensor:
      return ATan2.apply(a, b).realize()

    # first iterations warm up / capture the JIT; later ones replay it
    for _ in range(5):
      a = Tensor.randn(4,4,requires_grad=True).permute(1,0)
      b = Tensor.randn(4,4,requires_grad=True).permute(1,0)
      c = jitted_atan2(a, b)
      np.testing.assert_allclose(c.numpy(), np.arctan2(a.numpy(), b.numpy()), atol=1e-5)
# run the tests when this file is executed directly
if __name__ == "__main__":
  unittest.main()