@@ -14,6 +14,7 @@ from exo.download.shard_download import ShardDownloader
 from concurrent.futures import ThreadPoolExecutor
 from .stateful_model import StatefulModel
 import asyncio
+import aiofiles

 Tensor.no_grad = True
 # default settings
@@ -22,36 +23,65 @@ TOP_K = 25
 TOP_P = 0.9
 ALPHA_F = 0.1
 ALPHA_P = 0.0
-MODEL_PARAMS = {
-  "1B": {
-    "args": {
-      "dim": 2048, "n_heads": 32, "n_kv_heads": 8, "n_layers": 16, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
-      "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
-    }, "files": 1
-  }, "3B": {
+
+async def get_model_config(model_path: Path) -> dict:
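+  # Read the model's Hugging Face config.json and translate it into the
+  # Transformer kwargs that the hardcoded MODEL_PARAMS table used to supply.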
+  config_path = model_path / "config.json"
+  if not config_path.exists():
+    raise ValueError(f"Config file not found at {config_path}")
+
+  async with aiofiles.open(config_path) as f:
+    config = json.loads(await f.read())
+
+  return {
     "args": {
-      "dim": 3072, "n_heads": 24, "n_kv_heads": 8, "n_layers": 28, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
-      "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
-    }, "files": 1
-  }, "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
-  "70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
-}
+      "dim": config["hidden_size"],
+      "n_heads": config["num_attention_heads"],
+      "n_kv_heads": config.get("num_key_value_heads", config["num_attention_heads"]),
+      "n_layers": config["num_hidden_layers"],
+      "norm_eps": config["rms_norm_eps"],
+      "rope_theta": config.get("rope_theta", 500000),
+      "vocab_size": config["vocab_size"],
+      "hidden_dim": config["intermediate_size"],
+      "rope_scaling": config.get("rope_scaling", None),
+      "tie_word_embeddings": config.get("tie_word_embeddings", False)
+    },
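+    # "num_shards" is not a standard HF config key; default to a single weight file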
+    "files": config.get("num_shards", 1)
+  }

+async def build_transformer(model_path: Path, shard: Shard, device=None):
+  # Get model config from HF config file
+  model_config = await get_model_config(model_path)
+
-def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
   # build model
   linear = nn.Linear
   with Context(THREEFRY=0):
-    model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
+    model = Transformer(**model_config["args"], linear=linear, max_context=8192, jit=True, shard=shard)

   # load weights
   if model_path.is_dir():
-    if (model_path/"model.safetensors.index.json").exists(): weights = load(str(model_path/"model.safetensors.index.json"), shard)
-    elif (model_path/"model.safetensors").exists(): weights = load(str(model_path/"model.safetensors"), shard)
-    else: weights = concat_weights([load(str(model_path/f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
+    if (model_path/"model.safetensors.index.json").exists():
+      weights = load(str(model_path/"model.safetensors.index.json"), shard)
+    elif (model_path/"model.safetensors").exists():
+      weights = load(str(model_path/"model.safetensors"), shard)
+    else:
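+      # legacy llama checkpoints ship as consolidated.NN.pth shards; merge them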
+      weights = concat_weights(
+        [load(str(model_path/f"consolidated.{i:02d}.pth"), shard) for i in range(model_config["files"])],
+        device[0] if isinstance(device, tuple) else device
+      )
   else:
     weights = load(str(model_path), shard)
-  weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
+
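+  # remap HF weight names/layouts onto the tinygrad llama implementation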
+  weights = convert_from_huggingface(
+    weights,
+    model,
+    model_config["args"]["n_heads"],
+    model_config["args"]["n_kv_heads"]
+  )
   weights = fix_bf16(weights)

   with Context(BEAM=0):
@@ -76,7 +101,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     await self.ensure_shard(shard)
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
     return np.array(tokens)
-
+
   async def decode(self, shard: Shard, tokens) -> str:
     await self.ensure_shard(shard)
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
@@ -94,11 +119,11 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)

     if self.shard != shard:
-      loop = asyncio.get_running_loop()
-      parameters = "1B" if "1b" in shard.model_id.lower() else "3B" if "3b" in shard.model_id.lower() else "8B" if "8b" in shard.model_id.lower() else "70B"
-      model_shard = await loop.run_in_executor(self.executor, build_transformer, model_path, shard, parameters)
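+      # build_transformer is now a coroutine; run it under its own event loop
+      # on the executor thread so the blocking weight load stays off the main loop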
+      model_shard = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: asyncio.run(build_transformer(model_path, shard)))

       tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
       self.tokenizer = await resolve_tokenizer(tokenizer_path)
       self.shard = shard
-      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard)
+      self.model = await asyncio.get_running_loop().run_in_executor(self.executor, StatefulModel, model_shard)