
tinygrad dynamic config

Alex Cheema, 5 months ago
parent
commit 1c25375391
2 changed files with 48 additions and 25 deletions
  1. exo/inference/tinygrad/inference.py (+47, -24)
  2. exo/inference/tinygrad/models/llama.py (+1, -1)

exo/inference/tinygrad/inference.py (+47, -24)

@@ -14,6 +14,7 @@ from exo.download.shard_download import ShardDownloader
 from concurrent.futures import ThreadPoolExecutor
 from .stateful_model import StatefulModel
 import asyncio
+import aiofiles
 
 Tensor.no_grad = True
 # default settings
@@ -22,36 +23,60 @@ TOP_K = 25
 TOP_P = 0.9
 ALPHA_F = 0.1
 ALPHA_P = 0.0
-MODEL_PARAMS = {
-  "1B": {
-    "args": {
-      "dim": 2048, "n_heads": 32, "n_kv_heads": 8, "n_layers": 16, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
-      "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
-    }, "files": 1
-  }, "3B": {
+
+async def get_model_config(model_path: Path) -> dict:
+  config_path = model_path / "config.json"
+  if not config_path.exists():
+    raise ValueError(f"Config file not found at {config_path}")
+
+  async with aiofiles.open(config_path) as f:
+    config = json.loads(await f.read())
+
+  return {
     "args": {
-      "dim": 3072, "n_heads": 24, "n_kv_heads": 8, "n_layers": 28, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
-      "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
-    }, "files": 1
-  }, "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
-  "70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
-}
+      "dim": config["hidden_size"],
+      "n_heads": config["num_attention_heads"],
+      "n_kv_heads": config.get("num_key_value_heads", config["num_attention_heads"]),
+      "n_layers": config["num_hidden_layers"],
+      "norm_eps": config["rms_norm_eps"],
+      "rope_theta": config.get("rope_theta", 500000),
+      "vocab_size": config["vocab_size"],
+      "hidden_dim": config["intermediate_size"],
+      "rope_scaling": config.get("rope_scaling", None),
+      "tie_word_embeddings": config.get("tie_word_embeddings", False)
+    },
+    "files": config.get("num_shards", 1)
+  }
 
+async def build_transformer(model_path: Path, shard: Shard, device=None):
+  # Get model config from HF config file
+  model_config = await get_model_config(model_path)
 
-def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
   # build model
   linear = nn.Linear
   with Context(THREEFRY=0):
-    model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
+    model = Transformer(**model_config["args"], linear=linear, max_context=8192, jit=True, shard=shard)
 
   # load weights
   if model_path.is_dir():
-    if (model_path/"model.safetensors.index.json").exists(): weights = load(str(model_path/"model.safetensors.index.json"), shard)
-    elif (model_path/"model.safetensors").exists(): weights = load(str(model_path/"model.safetensors"), shard)
-    else: weights = concat_weights([load(str(model_path/f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
+    if (model_path/"model.safetensors.index.json").exists():
+      weights = load(str(model_path/"model.safetensors.index.json"), shard)
+    elif (model_path/"model.safetensors").exists():
+      weights = load(str(model_path/"model.safetensors"), shard)
+    else:
+      weights = concat_weights(
+        [load(str(model_path/f"consolidated.{i:02d}.pth"), shard) for i in range(model_config["files"])],
+        device[0] if isinstance(device, tuple) else device
+      )
   else:
     weights = load(str(model_path), shard)
-  weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
+
+  weights = convert_from_huggingface(
+    weights,
+    model,
+    model_config["args"]["n_heads"],
+    model_config["args"]["n_kv_heads"]
+  )
   weights = fix_bf16(weights)
 
   with Context(BEAM=0):
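
Aside (not part of the diff): a rough usage sketch of the new async helper. The model directory below is a made-up example; for a Llama-3.2-1B-style config.json the mapping should recover roughly the values of the hard-coded "1B" entry this commit removes (dim 2048, n_heads 32, n_kv_heads 8, n_layers 16).

```python
# Illustration only, not part of the commit. The model directory is hypothetical;
# point it at any downloaded HF model folder that contains a config.json.
import asyncio
from pathlib import Path

from exo.inference.tinygrad.inference import get_model_config

async def main():
  model_path = Path.home()/".cache"/"exo"/"downloads"/"llama-3.2-1b"  # hypothetical path
  cfg = await get_model_config(model_path)
  # For a Llama-3.2-1B-style config this should print something close to:
  #   {'dim': 2048, 'n_heads': 32, 'n_kv_heads': 8, 'n_layers': 16, ...}
  print(cfg["args"])

asyncio.run(main())
```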
@@ -76,7 +101,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     await self.ensure_shard(shard)
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
     return np.array(tokens)
-  
+
   async def decode(self, shard: Shard, tokens) -> str:
     await self.ensure_shard(shard)
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
@@ -94,11 +119,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
 
     if self.shard != shard:
-      loop = asyncio.get_running_loop()
-      parameters = "1B" if "1b" in shard.model_id.lower() else "3B" if "3b" in shard.model_id.lower() else "8B" if "8b" in shard.model_id.lower() else "70B"
-      model_shard = await loop.run_in_executor(self.executor, build_transformer, model_path, shard, parameters)
+      model_shard = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: asyncio.run(build_transformer(model_path, shard)))
 
       tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
       self.tokenizer = await resolve_tokenizer(tokenizer_path)
       self.shard = shard
-      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 
+      self.model = await asyncio.get_running_loop().run_in_executor(self.executor, StatefulModel, model_shard)
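
Side note on the ensure_shard hunk above: build_transformer is now a coroutine, so it is handed to the thread-pool executor wrapped in a lambda that calls asyncio.run, giving the worker thread its own event loop while the caller's loop stays free during weight loading. A minimal, self-contained sketch of that pattern (names are illustrative, not from the diff):

```python
# Sketch of the run_in_executor + asyncio.run pattern used in ensure_shard.
import asyncio
from concurrent.futures import ThreadPoolExecutor

async def slow_build(x: int) -> int:
  await asyncio.sleep(0.1)  # stands in for async work such as aiofiles reads
  return x * 2

async def main():
  executor = ThreadPoolExecutor(max_workers=1)
  loop = asyncio.get_running_loop()
  # The lambda runs on the worker thread; asyncio.run() starts a fresh event
  # loop there, so the outer loop is not blocked while the coroutine runs.
  result = await loop.run_in_executor(executor, lambda: asyncio.run(slow_build(21)))
  print(result)  # 42

asyncio.run(main())
```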

exo/inference/tinygrad/models/llama.py (+1, -1)

@@ -207,7 +207,7 @@ class Transformer:
     h = x
 
     if cache is None:
-      cache = [None for _ in range(self.shard.start_layer, self.shard.end_layer + 1)]  
+      cache = [None for _ in range(self.shard.start_layer, self.shard.end_layer + 1)]
     for i, c in zip(range(self.shard.start_layer, self.shard.end_layer + 1), cache):
       layer = self.layers[i]
       h = layer(h, start_pos, freqs_cis, mask, cache=c)