1
0

inference.py 4.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. from pathlib import Path
  2. import json
  3. import os
  4. from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
  5. from exo.inference.shard import Shard
  6. from tinygrad.nn.state import safe_load, torch_load, load_state_dict
  7. from tinygrad import Tensor, dtypes, nn, Context
  8. from transformers import AutoTokenizer
  9. from exo.inference.inference_engine import InferenceEngine
  10. from typing import Optional, Tuple
  11. import numpy as np
  12. from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
  13. from exo.download.shard_download import ShardDownloader
# Module-level side effect: sets the class-wide `no_grad` flag on tinygrad's
# Tensor as soon as this module is imported.
# NOTE(review): presumably this turns off gradient tracking because this module
# only runs inference -- confirm against the tinygrad version in use.
Tensor.no_grad = True
  15. # default settings
  16. TEMPERATURE = int(os.getenv("TEMPERATURE", 0.85))
  17. TOP_K = 25
  18. TOP_P = 0.9
  19. ALPHA_F = 0.1
  20. ALPHA_P = 0.0
  21. MODEL_PARAMS = {
  22. "8B": {
  23. "args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336},
  24. "files": 1
  25. },
  26. "70B": {
  27. "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672},
  28. "files": 8
  29. }
  30. }
  31. def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
  32. # build model
  33. linear = nn.Linear
  34. with Context(THREEFRY=0):
  35. model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
  36. # load weights
  37. if model_path.is_dir():
  38. if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"), shard)
  39. elif (model_path / "model.safetensors").exists(): weights = load(str(model_path / "model.safetensors"), shard)
  40. else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
  41. else:
  42. weights = load(str(model_path), shard)
  43. weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
  44. weights = fix_bf16(weights)
  45. with Context(BEAM=0):
  46. # replace weights in model
  47. load_state_dict(model, weights, strict=False, consume=False) # consume=True
  48. return model
  49. class TinygradDynamicShardInferenceEngine(InferenceEngine):
  50. def __init__(self, shard_downloader: ShardDownloader):
  51. self.shard = None
  52. self.shard_downloader = shard_downloader
  53. async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
  54. await self.ensure_shard(shard)
  55. start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
  56. n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
  57. toks = self.tokenizer.encode(prompt)
  58. h = self.model(Tensor([toks]), start_pos, TEMPERATURE).realize()
  59. if h.shape == (1,):
  60. start_pos += len(toks)
  61. start_pos += 1
  62. n_captured_toks = 0
  63. return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
  64. else:
  65. n_captured_toks = len(toks)
  66. return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
  67. async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
  68. await self.ensure_shard(shard)
  69. start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
  70. n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
  71. h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize()
  72. if h.shape == (1,):
  73. start_pos += n_captured_toks
  74. start_pos += 1
  75. n_captured_toks = 0
  76. return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
  77. else:
  78. return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
  79. async def ensure_shard(self, shard: Shard):
  80. if self.shard == shard:
  81. return
  82. model_path = await self.shard_downloader.ensure_shard(shard)
  83. self.model = build_transformer(model_path, shard, model_size="8B" if "8b" in shard.model_id.lower() else "70B")
  84. self.tokenizer = AutoTokenizer.from_pretrained(str((model_path if model_path.is_dir() else model_path.parent)))
  85. self.shard = shard