|
@@ -7,7 +7,7 @@ from tiktoken.load import load_tiktoken_bpe
|
|
|
from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
|
|
|
from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
|
|
|
from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
|
|
|
-from tinygrad.helpers import Profiling, Timing, DEBUG, colored, fetch, tqdm
|
|
|
+from tinygrad.helpers import DEBUG, tqdm, _cache_dir
|
|
|
from exo.inference.shard import Shard
|
|
|
from exo.inference.inference_engine import InferenceEngine
|
|
|
import numpy as np
|
|
@@ -186,7 +186,35 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
return
|
|
|
|
|
|
model_path = Path(shard.model_id)
|
|
|
- size = "8B" # one of 8B or 70B for now
|
|
|
+
|
|
|
+ models_dir = Path(_cache_dir) / "downloads"
|
|
|
+ model_path = models_dir / shard.model_id
|
|
|
+ if model_path.exists():
|
|
|
+ size = "8B"  # one of 8B or 70B; required by build_transformer below (model files already cached)
|
|
|
+ else:
|
|
|
+ from tinygrad.helpers import fetch
|
|
|
+
|
|
|
+ if DEBUG >= 2: print(f"Fetching configuration for model {shard.model_id}...")
|
|
|
+ if shard.model_id == "llama3-8b-sfr":
|
|
|
+ fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
|
|
|
+ fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
|
|
|
+ fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
|
|
|
+ fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
|
|
|
+ fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
|
|
|
+ model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
|
|
|
+ size = "8B"
|
|
|
+ elif shard.model_id == "llama3-70b-sfr":
|
|
|
+ raise NotImplementedError("llama3-70b-sfr is not implemented for tinygrad")
|
|
|
+ # fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-70B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
|
|
|
+ # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
|
|
|
+ # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
|
|
|
+ # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
|
|
|
+ # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
|
|
|
+ # model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
|
|
|
+ # size = "70B"
|
|
|
+ else:
|
|
|
+ raise ValueError(f"tinygrad doesn't currently support arbitrary model downloading. unsupported model: {shard.model_id}")
|
|
|
+
|
|
|
model = build_transformer(model_path, shard=shard, model_size=size)
|
|
|
tokenizer = Tokenizer(str((model_path if model_path.is_dir() else model_path.parent) / "tokenizer.model"))
|
|
|
|