#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.tensor import Tensor
import torch
  6. def get_question_samp(bsz, seq_len, vocab_size, seed):
  7. np.random.seed(seed)
  8. in_ids= np.random.randint(vocab_size, size=(bsz, seq_len))
  9. mask = np.random.choice([True, False], size=(bsz, seq_len))
  10. seg_ids = np.random.randint(1, size=(bsz, seq_len))
  11. return in_ids, mask, seg_ids
  12. def set_equal_weights(mdl, torch_mdl):
  13. from tinygrad.nn.state import get_state_dict
  14. state, torch_state = get_state_dict(mdl), torch_mdl.state_dict()
  15. assert len(state) == len(torch_state)
  16. for k, v in state.items():
  17. assert k in torch_state
  18. torch_state[k].copy_(torch.from_numpy(v.numpy()))
  19. torch_mdl.eval()
  20. class TestBert(unittest.TestCase):
  21. def test_questions(self):
  22. from extra.models.bert import BertForQuestionAnswering
  23. from transformers import BertForQuestionAnswering as TorchBertForQuestionAnswering
  24. from transformers import BertConfig
  25. # small
  26. config = {
  27. 'vocab_size':24, 'hidden_size':2, 'num_hidden_layers':2, 'num_attention_heads':2,
  28. 'intermediate_size':32, 'hidden_dropout_prob':0.1, 'attention_probs_dropout_prob':0.1,
  29. 'max_position_embeddings':512, 'type_vocab_size':2
  30. }
  31. # Create in tinygrad
  32. Tensor.manual_seed(1337)
  33. mdl = BertForQuestionAnswering(**config)
  34. # Create in torch
  35. with torch.no_grad():
  36. torch_mdl = TorchBertForQuestionAnswering(BertConfig(**config))
  37. set_equal_weights(mdl, torch_mdl)
  38. seeds = (1337, 3141)
  39. bsz, seq_len = 1, 16
  40. for _, seed in enumerate(seeds):
  41. in_ids, mask, seg_ids = get_question_samp(bsz, seq_len, config['vocab_size'], seed)
  42. out = mdl(Tensor(in_ids), Tensor(mask), Tensor(seg_ids))
  43. torch_out = torch_mdl.forward(torch.from_numpy(in_ids).long(), torch.from_numpy(mask), torch.from_numpy(seg_ids).long())[:2]
  44. torch_out = torch.cat(torch_out).unsqueeze(2)
  45. np.testing.assert_allclose(out.numpy(), torch_out.detach().numpy(), atol=5e-4, rtol=5e-4)
  46. if __name__ == '__main__':
  47. unittest.main()