# external_llama_eval.py
import argparse
import json

import torch

from lm_eval import evaluator, tasks
from lm_eval.base import BaseLM

from examples.llama import LLaMa
from tinygrad import Device
from tinygrad.tensor import Tensor
  7. class LLaMaAdaptor(BaseLM):
  8. def __init__(
  9. self,
  10. model_size="7B",
  11. model_gen=1,
  12. device="",
  13. quantize=False,
  14. batch_size=1,
  15. max_batch_size=1,
  16. do_sample=False,
  17. temperature=1.0,
  18. checkpoint_path="",
  19. tokenizer_path="",
  20. ):
  21. super().__init__()
  22. if batch_size is None:
  23. batch_size = 1
  24. self.do_sample = do_sample
  25. self.temperature = temperature
  26. self._device = device
  27. assert isinstance(model_gen, int)
  28. assert isinstance(model_size, str)
  29. assert isinstance(batch_size, int)
  30. assert isinstance(checkpoint_path, str)
  31. assert isinstance(tokenizer_path, str)
  32. self.llama = LLaMa.build(checkpoint_path, tokenizer_path, model_gen, model_size, quantize)
  33. @classmethod
  34. def create_from_arg_string(cls, arg_string, additional_config=None):
  35. kwargs = {el.split("=")[0]: el.split("=")[1] for el in arg_string.split(",")}
  36. return cls(**kwargs, **additional_config)
  37. @property
  38. def eot_token_id(self):
  39. # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
  40. return self.llama.tokenizer.eos_id()
  41. @property
  42. def max_length(self):
  43. return 1024
  44. @property
  45. def max_gen_toks(self):
  46. return 256
  47. @property
  48. def batch_size(self):
  49. return 1
  50. @property
  51. def device(self):
  52. return self._device
  53. def tok_encode(self, string: str):
  54. return [self.llama.tokenizer.bos_id()] + self.llama.tokenizer.encode(string)
  55. def tok_decode(self, tokens):
  56. return self.llama.tokenizer.decode(tokens)
  57. def _model_call(self, inps):
  58. Tensor.no_grad = True
  59. return torch.Tensor(self.llama.model(Tensor(inps.numpy()), 0).numpy())
  60. def greedy_until(self, requests):
  61. continuations = []
  62. for request in requests:
  63. prompt, until = request[0], request[1]['until']
  64. output = self.llama.greedy_until(prompt, until, max_length=128, temperature=0.0)
  65. continuations.append(output[len(prompt):])
  66. return continuations
  67. def _model_generate(self, context, max_length, eos_token_id):
  68. raise NotImplementedError()
  69. if __name__ == '__main__':
  70. print(f"using {Device.DEFAULT} backend")
  71. parser = argparse.ArgumentParser(description='Run LLaMA evals in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  72. parser.add_argument('--size', type=str, default="7B", help="Size of model to use [7B, 13B, 30B, 65B] for Gen 1, [7B, 13B] for Gen 2")
  73. parser.add_argument('--gen', type=int, default="1", help="Generation of the model to use [1, 2]")
  74. parser.add_argument('--quantize', action='store_true', help="Quantize the weights to int8 in memory")
  75. parser.add_argument('--eval', type=str, default="arc_easy", help="Run in evaluation mode")
  76. parser.add_argument('--limit', type=int, default=None, help="Limit tests in eval")
  77. parser.add_argument('--weights', type=str, default="./weights/LLaMa/", help="Location of the weights")
  78. parser.add_argument('--tokenizer', type=str, default="./weights/LLaMa/tokenizer.model", help="Location of the tokenizer")
  79. args = parser.parse_args()
  80. # run eval and exit
  81. adaptor = LLaMaAdaptor(model_gen=args.gen, model_size=args.size, quantize=args.quantize,
  82. checkpoint_path=args.weights, tokenizer_path=args.tokenizer, device="cpu")
  83. results = evaluator.evaluate(adaptor, tasks.get_task_dict(args.eval.split(",")), False, 0, args.limit)
  84. print(json.dumps(results, indent=2))