test_rnnt.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. #!/usr/bin/env python
  2. import unittest
  3. import numpy as np
  4. from tinygrad.tensor import Tensor
  5. from extra.models.rnnt import LSTM
  6. import torch
  7. class TestRNNT(unittest.TestCase):
  8. def test_lstm(self):
  9. BS, SQ, IS, HS, L = 2, 20, 40, 128, 2
  10. # create in torch
  11. with torch.no_grad():
  12. torch_layer = torch.nn.LSTM(IS, HS, L)
  13. # create in tinygrad
  14. layer = LSTM(IS, HS, L, 0.0)
  15. # copy weights
  16. with torch.no_grad():
  17. layer.cells[0].weights_ih.assign(Tensor(torch_layer.weight_ih_l0.numpy()))
  18. layer.cells[0].weights_hh.assign(Tensor(torch_layer.weight_hh_l0.numpy()))
  19. layer.cells[0].bias_ih.assign(Tensor(torch_layer.bias_ih_l0.numpy()))
  20. layer.cells[0].bias_hh.assign(Tensor(torch_layer.bias_hh_l0.numpy()))
  21. layer.cells[1].weights_ih.assign(Tensor(torch_layer.weight_ih_l1.numpy()))
  22. layer.cells[1].weights_hh.assign(Tensor(torch_layer.weight_hh_l1.numpy()))
  23. layer.cells[1].bias_ih.assign(Tensor(torch_layer.bias_ih_l1.numpy()))
  24. layer.cells[1].bias_hh.assign(Tensor(torch_layer.bias_hh_l1.numpy()))
  25. # test initial hidden
  26. for _ in range(3):
  27. x = Tensor.randn(SQ, BS, IS)
  28. z, hc = layer(x, None)
  29. torch_x = torch.tensor(x.numpy())
  30. torch_z, torch_hc = torch_layer(torch_x)
  31. np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
  32. # test passing hidden
  33. for _ in range(3):
  34. x = Tensor.randn(SQ, BS, IS)
  35. z, hc = layer(x, hc)
  36. torch_x = torch.tensor(x.numpy())
  37. torch_z, torch_hc = torch_layer(torch_x, torch_hc)
  38. np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
  39. if __name__ == '__main__':
  40. unittest.main()