# external_test_mamba.py — regression tests for text generation with the Mamba example models.
import unittest
from tinygrad.helpers import CI
from examples.mamba import Mamba, generate
from transformers import AutoTokenizer

# Shared prompt fed to every model under test.
PROMPT = 'Why is gravity '
# Tokenizer used by generate(); Mamba checkpoints use the GPT-NeoX vocabulary.
# NOTE(review): this runs at import time and may download the tokenizer from
# the HuggingFace hub — network access is assumed when this module is imported.
TOKENIZER = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
  7. @unittest.skipIf(CI, "model is slow for CI")
  8. class TestMamba(unittest.TestCase):
  9. def test_mamba_130M(self):
  10. OUT_130M = '''Why is gravity \nnot a good idea?\n\nA:'''
  11. model = Mamba.from_pretrained('130m')
  12. tinyoutput = generate(model, TOKENIZER, PROMPT, n_tokens_to_gen=10)
  13. self.assertEqual(OUT_130M, tinyoutput)
  14. del model
  15. def test_mamba_370M(self):
  16. OUT_370M = '''Why is gravity \nso important?\nBecause it's the only'''
  17. model = Mamba.from_pretrained('370m')
  18. tinyoutput = generate(model, TOKENIZER, PROMPT, n_tokens_to_gen=10)
  19. self.assertEqual(OUT_370M, tinyoutput)
  20. del model
# Allow running this file directly; unittest.main() discovers the TestMamba
# cases above and exits with a status reflecting the test results.
if __name__ == '__main__':
  unittest.main()