# inference.py (6.4 KB)
  1. from pathlib import Path
  2. import json
  3. import os
  4. from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16, sample_logits
  5. from exo.inference.shard import Shard
  6. from exo.inference.tokenizers import resolve_tokenizer
  7. from tinygrad.nn.state import load_state_dict
  8. from tinygrad import Tensor, nn, Context, TinyJit
  9. from exo.inference.inference_engine import InferenceEngine
  10. import numpy as np
  11. from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
  12. from exo.download.shard_download import ShardDownloader
  13. from concurrent.futures import ThreadPoolExecutor
  14. from .stateful_model import StatefulModel
  15. from .losses import length_masked_ce_loss
  16. import asyncio
  17. Tensor.no_grad = False
  18. # default settings
  19. TEMPERATURE = int(os.getenv("TEMPERATURE", 0.85))
  20. TOP_K = 25
  21. TOP_P = 0.9
  22. ALPHA_F = 0.1
  23. ALPHA_P = 0.0
  24. MODEL_PARAMS = {
  25. "1B": {
  26. "args": {
  27. "dim": 2048, "n_heads": 32, "n_kv_heads": 8, "n_layers": 16, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
  28. "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
  29. }, "files": 1
  30. }, "3B": {
  31. "args": {
  32. "dim": 3072, "n_heads": 24, "n_kv_heads": 8, "n_layers": 28, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
  33. "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
  34. }, "files": 1
  35. }, "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
  36. "70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
  37. }
  38. def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
  39. # build model
  40. linear = nn.Linear
  41. model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
  42. # load weights
  43. if model_path.is_dir():
  44. if (model_path/"model.safetensors.index.json").exists(): weights = load(str(model_path/"model.safetensors.index.json"), shard)
  45. elif (model_path/"model.safetensors").exists(): weights = load(str(model_path/"model.safetensors"), shard)
  46. else: weights = concat_weights([load(str(model_path/f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
  47. else:
  48. weights = load(str(model_path), shard)
  49. weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
  50. weights = fix_bf16(weights)
  51. with Context(BEAM=0):
  52. # replace weights in model
  53. load_state_dict(model, weights, strict=False, consume=False) # consume=True
  54. return model
  55. class TinygradDynamicShardInferenceEngine(InferenceEngine):
  56. def __init__(self, shard_downloader: ShardDownloader):
  57. self.shard = None
  58. self.shard_downloader = shard_downloader
  59. self.executor = ThreadPoolExecutor(max_workers=1)
  60. self.session = {}
  61. async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
  62. logits = x[:, -1, :]
  63. def sample_wrapper():
  64. return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
  65. return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
  66. async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
  67. await self.ensure_shard(shard)
  68. tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
  69. return await asyncio.get_running_loop().run_in_executor(self.executor, np.array, tokens)
  70. async def decode(self, shard: Shard, tokens) -> str:
  71. await self.ensure_shard(shard)
  72. tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
  73. return tokens
  74. async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
  75. await self.ensure_shard(shard)
  76. output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize())
  77. return output_data.numpy()
  78. async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
  79. def step(x, y, l):
  80. Tensor.training = False
  81. return self.session['loss'](self.model, x, y, l)
  82. await self.ensure_shard(shard)
  83. await self.ensure_session('loss', lambda: loss)
  84. await self.ensure_session('jit', lambda: TinyJit(step))
  85. score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.session['jit'](Tensor(inputs), targets, lengths))
  86. out = score.numpy()
  87. return out
  88. async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss, opt=nn.optim.Adam, lr=1e-5):
  89. def step(x, y, l):
  90. Tensor.training = True
  91. score = self.session['loss'](self.model, x, y, l)
  92. self.session['opt'].zero_grad()
  93. score.backward()
  94. self.session['opt'].step()
  95. return score
  96. await self.ensure_shard(shard)
  97. await self.ensure_session('loss', lambda: loss)
  98. await self.ensure_session('opt', lambda: opt(nn.state.get_parameters(self.model.model), lr=lr))
  99. await self.ensure_session('jit', lambda: TinyJit(step))
  100. score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.session['jit'](Tensor(inputs), targets, lengths).realize())
  101. return loss.numpy(), loss.numpy()
  102. async def ensure_shard(self, shard: Shard):
  103. if self.shard == shard:
  104. return
  105. model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
  106. if self.shard != shard:
  107. loop = asyncio.get_running_loop()
  108. parameters = "1B" if "1b" in shard.model_id.lower() else "3B" if "3b" in shard.model_id.lower() else "8B" if "8b" in shard.model_id.lower() else "70B"
  109. model_shard = await loop.run_in_executor(self.executor, build_transformer, model_path, shard, parameters)
  110. tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
  111. self.tokenizer = await resolve_tokenizer(tokenizer_path)
  112. self.shard = shard
  113. self.model = model_shard