| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556 |
- #!/usr/bin/env python
- import unittest
- import numpy as np
- from tinygrad.tensor import Tensor
- import torch
def get_question_samp(bsz, seq_len, vocab_size, seed):
    """Generate a deterministic random input batch for BERT question answering.

    Args:
        bsz: batch size.
        seq_len: sequence length.
        vocab_size: token ids are drawn uniformly from ``[0, vocab_size)``.
        seed: seed for a private NumPy ``RandomState``.

    Returns:
        ``(in_ids, mask, seg_ids)``, each a NumPy array of shape
        ``(bsz, seq_len)``:
        - ``in_ids``: integer token ids.
        - ``mask``: random boolean attention mask.
        - ``seg_ids``: integer segment ids, always 0 (see below).
    """
    # A private RandomState produces exactly the same stream as
    # np.random.seed(seed) followed by np.random.* calls, but it leaves the
    # process-global NumPy RNG untouched.
    rng = np.random.RandomState(seed)
    in_ids = rng.randint(vocab_size, size=(bsz, seq_len))
    mask = rng.choice([True, False], size=(bsz, seq_len))
    # randint(1) samples from [0, 1), so every segment id is 0.
    seg_ids = rng.randint(1, size=(bsz, seq_len))
    return in_ids, mask, seg_ids
def set_equal_weights(mdl, torch_mdl):
    """Copy every parameter of the tinygrad model ``mdl`` into ``torch_mdl``.

    After copying, ``torch_mdl`` is switched to eval mode.

    Args:
        mdl: tinygrad model; its tensors are read via ``get_state_dict``.
        torch_mdl: torch module whose ``state_dict()`` tensors are
            overwritten in place.

    Raises:
        AssertionError: if the two state dicts do not have exactly the same
            keys, or if any pair of tensors differs in shape.
    """
    from tinygrad.nn.state import get_state_dict

    state, torch_state = get_state_dict(mdl), torch_mdl.state_dict()
    differing = set(state) ^ set(torch_state)
    assert not differing, f"state dict keys differ: {sorted(differing)}"
    for k, v in state.items():
        src = torch.from_numpy(v.numpy())
        dst = torch_state[k]
        # Tensor.copy_ broadcasts, so a layout mismatch could otherwise be
        # copied silently instead of failing.
        assert dst.shape == src.shape, (
            f"shape mismatch for {k}: torch {tuple(dst.shape)} "
            f"vs tinygrad {tuple(src.shape)}"
        )
        dst.copy_(src)
    # eval() disables torch dropout so the forward pass is deterministic.
    torch_mdl.eval()
class TestBert(unittest.TestCase):
    """Check tinygrad's BERT models against the Hugging Face reference."""

    def test_questions(self):
        """tinygrad BertForQuestionAnswering matches transformers' outputs."""
        from extra.models.bert import BertForQuestionAnswering
        from transformers import BertForQuestionAnswering as TorchBertForQuestionAnswering
        from transformers import BertConfig

        # small config so the test runs quickly
        config = {
            'vocab_size': 24, 'hidden_size': 2, 'num_hidden_layers': 2, 'num_attention_heads': 2,
            'intermediate_size': 32, 'hidden_dropout_prob': 0.1, 'attention_probs_dropout_prob': 0.1,
            'max_position_embeddings': 512, 'type_vocab_size': 2,
        }

        # Create in tinygrad
        Tensor.manual_seed(1337)
        mdl = BertForQuestionAnswering(**config)

        # Create in torch, then overwrite its weights with tinygrad's
        with torch.no_grad():
            torch_mdl = TorchBertForQuestionAnswering(BertConfig(**config))
            set_equal_weights(mdl, torch_mdl)

        seeds = (1337, 3141)
        bsz, seq_len = 1, 16
        for seed in seeds:
            # subTest reports which seed failed and still runs the others.
            with self.subTest(seed=seed):
                in_ids, mask, seg_ids = get_question_samp(bsz, seq_len, config['vocab_size'], seed)
                out = mdl(Tensor(in_ids), Tensor(mask), Tensor(seg_ids))
                # Inference only: no_grad avoids building an autograd graph.
                # Calling the module (rather than .forward) keeps torch hooks.
                with torch.no_grad():
                    torch_out = torch_mdl(
                        torch.from_numpy(in_ids).long(),
                        torch.from_numpy(mask),
                        torch.from_numpy(seg_ids).long(),
                    )[:2]
                # NOTE(review): cat along dim 0 + unsqueeze(2) gives
                # (2 * bsz, seq_len, 1); this layout has only been exercised
                # here with bsz == 1. Confirm it matches tinygrad's output
                # before increasing bsz.
                torch_out = torch.cat(torch_out).unsqueeze(2)
                np.testing.assert_allclose(out.numpy(), torch_out.numpy(), atol=5e-4, rtol=5e-4)
if __name__ == "__main__":
    # Executed directly (not imported): discover and run the tests in this module.
    unittest.main()
|