Prechádzať zdrojové kódy

fetch model if doesnt exist on tinygrad

Alex Cheema 1 rok pred
rodič
commit
6b3727f023
1 zmenil súbory, kde vykonal 30 pridanie a 2 odobranie
  1. 30 2
      exo/inference/tinygrad/inference.py

+ 30 - 2
exo/inference/tinygrad/inference.py

@@ -7,7 +7,7 @@ from tiktoken.load import load_tiktoken_bpe
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
 from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
-from tinygrad.helpers import Profiling, Timing, DEBUG, colored, fetch, tqdm
+from tinygrad.helpers import DEBUG, tqdm, _cache_dir
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import InferenceEngine
 import numpy as np
@@ -186,7 +186,35 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
             return
 
         model_path = Path(shard.model_id)
-        size = "8B" # one of 8B or 70B for now
+
+        models_dir = Path(_cache_dir) / "downloads"
+        model_path = models_dir / shard.model_id
+        if model_path.exists():
+            model = model_path
+        else:
+            from tinygrad.helpers import fetch
+
+            if DEBUG >= 2: print(f"Fetching configuration for model {shard.model_id}...")
+            if shard.model_id == "llama3-8b-sfr":
+                fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
+                fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
+                fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
+                fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
+                fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
+                model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
+                size = "8B"
+            elif shard.model_id == "llama3-70b-sfr":
+                raise NotImplementedError("llama3-70b-sfr is not implemented for tinygrad")
+                # fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-70B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
+                # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
+                # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
+                # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
+                # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
+                # model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
+                # size = "70B"
+            else:
+                raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {shard.model_id}")
+
         model = build_transformer(model_path, shard=shard, model_size=size)
         tokenizer = Tokenizer(str((model_path if model_path.is_dir() else model_path.parent) / "tokenizer.model"))