# inference.py

from pathlib import Path
import json
import os
from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16, sample_logits
from exo.inference.shard import Shard
from exo.inference.tokenizers import resolve_tokenizer
from tinygrad.nn.state import load_state_dict, get_state_dict, get_parameters, safe_load, safe_save
from tinygrad import Tensor, nn, Context, TinyJit
from exo.inference.inference_engine import InferenceEngine
import numpy as np
from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
from exo.download.shard_download import ShardDownloader
from concurrent.futures import ThreadPoolExecutor
from .stateful_model import make_prompt_state
from .losses import length_masked_ce_loss
from collections import OrderedDict
import asyncio

Tensor.no_grad = True

# default sampling settings
TEMPERATURE = float(os.getenv("TEMPERATURE", 0.85))  # float, not int: int() would truncate 0.85 to 0
TOP_K = 25
TOP_P = 0.9
ALPHA_F = 0.1
ALPHA_P = 0.0

MODEL_PARAMS = {
  "1B": {
    "args": {
      "dim": 2048, "n_heads": 32, "n_kv_heads": 8, "n_layers": 16, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
      "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
    }, "files": 1
  }, "3B": {
    "args": {
      "dim": 3072, "n_heads": 24, "n_kv_heads": 8, "n_layers": 28, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
      "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
    }, "files": 1
  }, "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
  "70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
}
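
# A quick sanity check on the configs above: head_dim = dim/n_heads is 64 for "1B"
# (2048/32) and 128 for the larger sizes (e.g. "8B": 4096/32), and with n_kv_heads = 8
# the "8B" model groups 32/8 = 4 query heads per KV head (grouped-query attention).
# "files" is how many consolidated.*.pth shards to concatenate when no safetensors
# weights are present (see build_transformer below).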

def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
  # build model
  linear = nn.Linear
  model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)

  # load weights
  if model_path.is_dir():
    if (model_path/"model.safetensors.index.json").exists(): weights = load(str(model_path/"model.safetensors.index.json"), shard)
    elif (model_path/"model.safetensors").exists(): weights = load(str(model_path/"model.safetensors"), shard)
    else: weights = concat_weights([load(str(model_path/f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
  else:
    weights = load(str(model_path), shard)
  weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
  weights = fix_bf16(weights)

  with Context(BEAM=0):
    # replace weights in model (consume=True would free entries from `weights` as they load)
    load_state_dict(model, weights, strict=False, consume=False)
  return model
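
# Usage sketch for build_transformer (hypothetical model id and path; the layer
# split below covers all 16 layers of the "1B" config):
#   shard = Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=15, n_layers=16)
#   model = build_transformer(Path.home()/"models"/"llama-3.2-1b", shard, model_size="1B")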

class TinygradDynamicShardInferenceEngine(InferenceEngine):
  def __init__(self, shard_downloader: ShardDownloader):
    self.shard = None
    self.shard_downloader = shard_downloader
    # single worker: all tinygrad work is serialized on one thread, off the event loop
    self.executor = ThreadPoolExecutor(max_workers=1)
    self.states = OrderedDict()  # request_id -> prompt state (KV cache), in LRU order
    self.session = {}  # lazily filled with the 'loss', 'jit', and 'opt' entries used by evaluate()/train()

  def poll_state(self, x, request_id: str, max_states=2):
    if request_id not in self.states:
      if len(self.states) >= max_states:
        # evict the least-recently-used request's KV cache
        self.states.popitem(last=False)
      self.states[request_id] = make_prompt_state(x, self.model, self.shard)
    else:
      self.states.move_to_end(request_id)
    state = self.states[request_id]
    return {"start_pos": state.start, "cache": state.cache}
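
  # Illustration of the LRU eviction above with the default max_states=2
  # (hypothetical request ids):
  #   poll_state(x, "a"); poll_state(x, "b")  # two live KV caches
  #   poll_state(x, "c")                      # evicts "a", the least recently used
  #   poll_state(x, "a")                      # rebuilds "a" from scratch (start_pos 0)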

  async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
    logits = x[:, -1, :]  # sample only from the last position's logits
    def sample_wrapper():
      # sample_logits takes (logits, temp, k, p, af, ap) positionally
      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
    return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)

  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
    await self.ensure_shard(shard)
    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
    return await asyncio.get_running_loop().run_in_executor(self.executor, np.array, tokens)

  async def decode(self, shard: Shard, tokens) -> str:
    await self.ensure_shard(shard)
    return await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)

  async def load_checkpoint(self, shard: Shard, path: str):
    await self.ensure_shard(shard)
    # minimal sketch, assuming a safetensors checkpoint written by save_checkpoint below
    await asyncio.get_running_loop().run_in_executor(
      self.executor, lambda: load_state_dict(self.model, safe_load(path), strict=False, consume=False))

  async def save_checkpoint(self, shard: Shard, path: str):
    await self.ensure_shard(shard)
    # minimal sketch: snapshot this shard's tensors and write them out as safetensors
    await asyncio.get_running_loop().run_in_executor(
      self.executor, lambda: safe_save(get_state_dict(self.model), path))
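
  # Usage sketch for the checkpoint helpers above (hypothetical path):
  #   await engine.save_checkpoint(shard, "/tmp/shard0.safetensors")
  #   await engine.load_checkpoint(shard, "/tmp/shard0.safetensors")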

  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
    await self.ensure_shard(shard)
    def wrap_infer():
      x = Tensor(input_data)
      state = self.poll_state(x, request_id)
      out = self.model(x, **state)
      # advance start_pos by the sequence length so the next call appends to the KV cache
      self.states[request_id].start += x.shape[1]
      return out.realize()
    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
    return output_data.numpy()

  async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
    def step(x, y, l):
      Tensor.training = False
      return self.session['loss'](self.model, x, y, l)
    await self.ensure_shard(shard)
    # session setup isn't shown in this file; populate it lazily here (an assumption:
    # one engine instance runs either evaluate() or train(), since both share the 'jit' key)
    self.session.setdefault('loss', loss)
    self.session.setdefault('jit', TinyJit(step))
    score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.session['jit'](Tensor(inputs), targets, lengths))
    return score.numpy()

  async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss, opt=nn.optim.Adam, lr=1e-5):
    def step(x, y, l):
      Tensor.training = True
      Tensor.no_grad = False  # re-enable autograd for the backward pass (disabled module-wide above)
      score = self.session['loss'](self.model, x, y, l)
      self.session['opt'].zero_grad()
      score.backward()
      self.session['opt'].step()
      return score
    await self.ensure_shard(shard)
    # lazy session setup, mirroring evaluate(): loss fn, optimizer over this shard's parameters, jitted step
    self.session.setdefault('loss', loss)
    self.session.setdefault('opt', opt(get_parameters(self.model), lr=lr))
    self.session.setdefault('jit', TinyJit(step))
    score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.session['jit'](Tensor(inputs), targets, lengths).realize())
    return score.numpy(), score.numpy()

  async def ensure_shard(self, shard: Shard):
    if self.shard == shard:
      return
    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
    if self.shard != shard:  # re-check: another task may have loaded this shard while we awaited the download
      loop = asyncio.get_running_loop()
      parameters = "1B" if "1b" in shard.model_id.lower() else "3B" if "3b" in shard.model_id.lower() else "8B" if "8b" in shard.model_id.lower() else "70B"
      model_shard = await loop.run_in_executor(self.executor, build_transformer, model_path, shard, parameters)
      tokenizer_path = str(model_path if model_path.is_dir() else model_path.parent)
      self.tokenizer = await resolve_tokenizer(tokenizer_path)
      self.shard = shard
      self.model = model_shard
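
if __name__ == "__main__":
  # End-to-end usage sketch, not exercised by exo itself. The downloader import path,
  # model id, and layer split below are assumptions for illustration only.
  from exo.download.hf.hf_shard_download import HFShardDownloader  # assumed downloader implementation

  async def _demo():
    engine = TinygradDynamicShardInferenceEngine(HFShardDownloader())
    shard = Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=15, n_layers=16)  # whole 1B model on one node
    tokens = await engine.encode(shard, "Hello")
    logits = await engine.infer_tensor("demo-request", shard, tokens.reshape(1, -1))
    next_token = await engine.sample(logits, temp=TEMPERATURE, top_p=TOP_P)
    print(await engine.decode(shard, next_token))

  asyncio.run(_demo())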