
Always call convert_from_huggingface with tinygrad models. This was broken by shard layer filtering, which made the check sometimes fail. Fixes #144

Alex Cheema 8 months ago
parent
commit
803dffd1c4
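The root cause: exo/inference/tinygrad/tinygrad_helpers.py filters checkpoint keys down to the layers a shard actually serves, so any shard that does not start at layer 0 never sees "model.embed_tokens.weight", and the old guard (if "model.embed_tokens.weight" in weights:) silently skipped the HF-to-tinygrad key conversion for that shard. A minimal sketch of the failure mode follows; filter_for_shard is illustrative only, not exo's real loader, and the key names assume the Hugging Face llama checkpoint layout.

# Illustrative only; not exo's real tinygrad_helpers.load(). Key names follow the
# Hugging Face llama checkpoint layout used by the 32-layer test model.
def filter_for_shard(weights: dict, start_layer: int, end_layer: int) -> dict:
  kept = {}
  for name, tensor in weights.items():
    if name.startswith("model.layers."):
      layer = int(name.split(".")[2])
      if start_layer <= layer <= end_layer:
        kept[name] = tensor
    elif name == "model.embed_tokens.weight" and start_layer == 0:
      kept[name] = tensor  # embeddings are only loaded by the first shard
  return kept

second_half = filter_for_shard(
  {"model.embed_tokens.weight": "emb", "model.layers.20.mlp.gate_proj.weight": "w20"},
  start_layer=16, end_layer=31,
)
# The key is absent on this shard, so the old guard evaluated False and
# convert_from_huggingface never ran; hence the move to calling it unconditionally.
assert "model.embed_tokens.weight" not in second_half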

+ 2 - 1
.circleci/config.yml

@@ -109,7 +109,8 @@ jobs:
           name: Run tests
           command: |
             source env/bin/activate
-            METAL_XCODE=1 python3 -m exo.inference.test_inference_engine
+            # set TEMPERATURE to 0 for deterministic sampling
+            METAL_XCODE=1 TEMPERATURE=0 python3 -m exo.inference.test_inference_engine
 
   discovery_integration_test:
     macos:
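TEMPERATURE=0 matters here because the test below asserts exact equality between full-model and sharded outputs, which only holds if token selection is deterministic. The sketch below shows the intended effect; pick_token is a simplified stand-in for the sampler in exo/inference/tinygrad/models/llama.py, which additionally applies top-k, top-p and alpha penalties.

import os
import numpy as np

# Simplified stand-in for exo's sampler; only the temperature handling is shown.
TEMPERATURE = float(os.getenv("TEMPERATURE", 0.85))

def pick_token(logits: np.ndarray) -> int:
  if TEMPERATURE < 1e-6:                       # TEMPERATURE=0: greedy argmax, deterministic
    return int(np.argmax(logits))
  probs = np.exp((logits - logits.max()) / TEMPERATURE)
  probs /= probs.sum()                         # softmax with temperature
  return int(np.random.choice(len(logits), p=probs))  # stochastic otherwise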

+ 19 - 15
exo/inference/test_inference_engine.py

@@ -1,17 +1,16 @@
-from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
+from exo.helpers import DEBUG
 import asyncio
 import numpy as np
-
+from transformers import AutoTokenizer
 
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
-async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
+async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, tokenizer: AutoTokenizer):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
   resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
-
   next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
@@ -19,22 +18,23 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
     inference_state=inference_state_full,
   )
 
-  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
+  pp = 15
+  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), prompt=prompt)
   resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
     "B",
-    shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=pp+1, end_layer=31, n_layers=32),
     input_data=resp1,
     inference_state=inference_state_1,
   )
   resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
     "B",
-    shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32),
     input_data=resp2,
     inference_state=inference_state_2,
   )
   resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
     "B",
-    shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=pp+1, end_layer=31, n_layers=32),
     input_data=resp3,
     inference_state=inference_state_3,
   )
@@ -42,7 +42,6 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   assert np.array_equal(resp_full, resp2)
   assert np.array_equal(next_resp_full, resp4)
 
-
 asyncio.run(
   test_inference_engine(
     MLXDynamicShardInferenceEngine(HFShardDownloader()),
@@ -51,9 +50,14 @@ asyncio.run(
   )
 )
 
-# TODO: Need more memory or a smaller model
-# asyncio.run(test_inference_engine(
-#     TinygradDynamicShardInferenceEngine(HFShardDownloader()),
-#     TinygradDynamicShardInferenceEngine(HFShardDownloader()),
-#     "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
-# ))
+import os
+if os.getenv("RUN_TINYGRAD", default="0") == "1":
+  import tinygrad
+  from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+  tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
+  asyncio.run(test_inference_engine(
+      TinygradDynamicShardInferenceEngine(HFShardDownloader()),
+      TinygradDynamicShardInferenceEngine(HFShardDownloader()),
+      "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
+      AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
+  ))
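The partition the test relies on: with pp = 15, engine 1 runs layers 0 through pp and engine 2 runs layers pp+1 through 31, a contiguous split that covers all 32 layers exactly once, matching the "Shards are continuous" note above. split_layers below is a throwaway helper, not part of exo.

# Throwaway helper illustrating the shard partition used in the test.
def split_layers(n_layers: int, pp: int):
  first = (0, pp)                   # engine 1's shard
  second = (pp + 1, n_layers - 1)   # engine 2's shard
  assert second[0] == first[1] + 1  # contiguous
  assert second[1] == n_layers - 1  # full coverage
  return first, second

print(split_layers(32, 15))  # ((0, 15), (16, 31))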

+ 10 - 9
exo/inference/tinygrad/inference.py

@@ -1,6 +1,6 @@
 from pathlib import Path
-from typing import List
 import json
+import os
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from exo.inference.shard import Shard
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict
@@ -14,7 +14,7 @@ from exo.download.shard_download import ShardDownloader
 
 Tensor.no_grad = True
 # default settings
-TEMPERATURE = 0.85
+TEMPERATURE = float(os.getenv("TEMPERATURE", 0.85))
 TOP_K = 25
 TOP_P = 0.9
 ALPHA_F = 0.1
@@ -43,8 +43,7 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
     else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
   else:
     weights = load(str(model_path), shard)
-  if "model.embed_tokens.weight" in weights:
-    weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
+  weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
   weights = fix_bf16(weights)
 
   with Context(BEAM=0):
@@ -63,15 +62,16 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
     toks = self.tokenizer.encode(prompt)
-    h = self.model(Tensor([toks]), start_pos, TEMPERATURE)
+    h = self.model(Tensor([toks]), start_pos, TEMPERATURE).realize()
 
     if h.shape == (1,):
       start_pos += len(toks)
-      n_captured_toks = 1
+      start_pos += 1
+      n_captured_toks = 0
       return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
     else:
-      n_captured_toks += len(toks)
-      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks + 1}), False
+      n_captured_toks = len(toks)
+      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     await self.ensure_shard(shard)
@@ -82,7 +82,8 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
 
     if h.shape == (1,):
       start_pos += n_captured_toks
-      n_captured_toks = 1
+      start_pos += 1
+      n_captured_toks = 0
       return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
     else:
       return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
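Both branches now follow the same bookkeeping pattern: once the model has produced an actual sampled token (h.shape == (1,)), the cache position advances past every token already fed in plus the new token, and n_captured_toks resets to zero (infer_prompt counts len(toks), infer_tensor counts n_captured_toks). advance_state below is a made-up helper that just illustrates the update.

# Made-up helper showing the state update applied when a token is sampled.
def advance_state(start_pos: int, already_fed: int):
  start_pos += already_fed   # tokens already pushed through the model
  start_pos += 1             # the token that was just sampled
  return start_pos, 0        # n_captured_toks resets; nothing is left uncounted

print(advance_state(start_pos=0, already_fed=12))  # (13, 0)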

+ 1 - 1
exo/inference/tinygrad/models/llama.py

@@ -107,7 +107,7 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
 
   # if temperature is very low just use argmax
-  if temp < 1e-6: return logits.argmax()
+  if temp < 1e-6: return logits.argmax().reshape(1)
 
   # alpha sampling
   if af or ap:
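The reshape matters because the inference engine branches on h.shape == (1,): argmax() yields a 0-d scalar tensor, while the temperature-sampling path yields a one-element tensor. The numpy lines below show the same shape mismatch; tinygrad tensors behave analogously for this check.

import numpy as np

logits = np.array([0.1, 2.0, 0.3])
print(np.argmax(logits).shape)             # () : a 0-d scalar, fails a shape == (1,) check
print(np.argmax(logits).reshape(1).shape)  # (1,) : matches the sampled-token path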

+ 1 - 1
exo/inference/tinygrad/tinygrad_helpers.py

@@ -35,7 +35,7 @@ def load(fn:str, shard: Shard):
 
       parts[n] = load(str(Path(fn).parent / Path(n).name), shard)
       filtered_weight_map[k] = n
-    if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {set(weight_map.keys()) - set(filtered_weight_map.keys())}")
+    if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}")
     return {k: parts[n][k] for k, n in filtered_weight_map.items()}
   elif fn.endswith(".safetensors"):
     return safe_load(fn)