
add an opaque inference_state that inference engines can use to pass small state between devices

Alex Cheema 11 months ago
parent
commit dd8d18128c
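
The net effect on the engine interface: infer_prompt and infer_tensor now take an optional inference_state string and return a three-tuple (output, inference_state, is_finished) instead of a pair. A minimal sketch of the updated shape, pieced together from the hunks below; the abstract base class itself is not part of this diff, and the hunks write the return annotation as a plain tuple rather than Tuple[...]:

    from typing import Optional, Tuple
    import numpy as np
    from exo.inference.shard import Shard

    class InferenceEngine:
        # The second element of the return value is an opaque, engine-defined
        # string that the caller forwards untouched to whichever device runs
        # the next shard.
        async def infer_prompt(self, shard: Shard, prompt: str,
                               inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
            ...

        async def infer_tensor(self, shard: Shard, input_data: np.ndarray,
                               inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
            ...

        async def reset_shard(self, shard: Shard) -> None:
            ...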

+ 1 - 1
exo/api/chatgpt_api.py

@@ -52,7 +52,7 @@ class ChatGPTAPI:
         except Exception as e:
             pass # TODO
             # return web.json_response({'detail': str(e)}, status=500)
-        
+
         # poll for the response. TODO: implement callback for specific request id
         timeout = 90
         start_time = time.time()

+ 8 - 8
exo/inference/mlx/sharded_inference_engine.py

@@ -13,20 +13,20 @@ class MLXFixedShardInferenceEngine(InferenceEngine):
         model_shard, self.tokenizer = load_shard(model_path, shard)
         self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
 
-    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, bool):
+    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
         if shard != self.shard:
             raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
 
         output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
         print(f"output_data size: {output_data.size}, output_data: {output_data}")
-        return output_data, output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+        return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
-    async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, bool):
+    async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, str, bool):
         if shard != self.shard:
             raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
 
         output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(input_data)))
-        return output_data, output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+        return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
     async def reset_shard(self, shard: Shard):
         if shard != self.shard:
@@ -39,15 +39,15 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     def __init__(self):
         self.shard = None
 
-    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, bool):
+    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
         await self.ensure_shard(shard)
         output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
-        return output_data, output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+        return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
-    async def infer_tensor(self, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, bool):
+    async def infer_tensor(self, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
         await self.ensure_shard(shard)
         output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(input_data)))
-        return output_data, output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+        return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
     async def reset_shard(self, shard: Shard):
         await self.ensure_shard(shard)
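
The MLX engines keep their KV cache inside the local StatefulShardedModel, so they have no cross-device state to serialize and always return an empty string as the opaque state. A small illustration of how a caller unpacks the new tuple (hypothetical driver code inside an async function; engine, shard and prompt are placeholders):

    # MLX: state stays on-device, so the opaque string is always "".
    output, inference_state, is_finished = await engine.infer_prompt(shard=shard, prompt=prompt)
    assert inference_state == ""
    # is_finished is True only when the step produced a single token equal to
    # tokenizer.eos_token_id; otherwise output (a token or an intermediate
    # activation, depending on the shard) is forwarded to the next device.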

+ 11 - 17
exo/inference/test_inference_engine.py

@@ -5,32 +5,26 @@ from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 import numpy as np
 
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
-async def test_inference_engine(inference_engine: InferenceEngine, model_id: str, input_data: np.array):
-    # inference_engine.reset_shard(Shard("", 0,0,0))
+async def test_inference_engine(inference_engine: InferenceEngine, model_id: str):
     prompt = "In a single word only, what is the capital of Japan? "
-    resp_full, _, _ = await inference_engine.infer_prompt(shard=Shard(model_id=model_id, start_layer=0, end_layer=1, n_layers=2), prompt=prompt)
+    resp_full, inference_state_full, _ = await inference_engine.infer_prompt(shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
 
-    print("resp_full", resp_full)
-    print("decoded", inference_engine.tokenizer.decode(resp_full))
+    await inference_engine.reset_shard(shard=Shard(model_id=model_id, start_layer=0, end_layer=10, n_layers=32))
+    resp1, inference_state, _ = await inference_engine.infer_prompt(shard=Shard(model_id=model_id, start_layer=0, end_layer=10, n_layers=32), prompt=prompt)
 
-    # inference_engine.reset_shard(Shard("", 0,0,0))
+    await inference_engine.reset_shard(shard=Shard(model_id=model_id, start_layer=11, end_layer=31, n_layers=32))
+    resp2, _, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=11, end_layer=31, n_layers=32), input_data=resp1, inference_state=inference_state)
 
-    resp1, inference_state, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=0, end_layer=0, n_layers=2), input_data=input_data)
-    print(f"Intermediate {inference_state=}")
-    resp2, _, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=1, end_layer=1, n_layers=2), input_data=resp1, inference_state=inference_state)
-
-    # assert np.array_equal(resp_full, resp2)
+    assert np.array_equal(resp_full, resp2)
 
 import asyncio
 
-# asyncio.run(test_inference_engine(
-#     MLXDynamicShardInferenceEngine(),
-#     "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
-#     [1234]
-# ))
+asyncio.run(test_inference_engine(
+    MLXDynamicShardInferenceEngine(),
+    "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+))
 
 asyncio.run(test_inference_engine(
     TinygradDynamicShardInferenceEngine(),
     "/Users/alex/Library/Caches/tinygrad/downloads/llama3-8b-sfr",
-    [1234]
 ))
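
The rewritten test exercises the whole pipeline: one run over all 32 layers, then the same prompt split into a 0-10 shard and an 11-31 shard with the opaque inference_state carried across the boundary, and the two final outputs asserted equal (presumably why TEMPERATURE drops to 0 in the tinygrad engine below, so sampling is deterministic and the comparison can hold). Both asyncio.run calls sit at module level, so running the file as a script executes both checks. The shard arithmetic the split relies on, spelled out as a small sketch (layer ranges are inclusive, per the layer-count and key-remapping changes in llama.py below):

    from exo.inference.shard import Shard

    model_id = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
    # Layers 0-10 on the first shard, 11-31 on the second: continuous, and
    # together covering all 32 layers exactly once.
    shard_a = Shard(model_id=model_id, start_layer=0, end_layer=10, n_layers=32)
    shard_b = Shard(model_id=model_id, start_layer=11, end_layer=31, n_layers=32)
    assert shard_b.start_layer == shard_a.end_layer + 1
    assert shard_b.end_layer == shard_b.n_layers - 1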

+ 9 - 13
exo/inference/tinygrad/inference.py

@@ -52,7 +52,8 @@ class Tokenizer:
   @property
   def stop_tokens(self): return {self.special_tokens["<|end_of_text|>"], self.special_tokens["<|eot_id|>"]}
 
-  def decode(self, toks): return self.model.decode([t for t in toks if t < self.num_base_tokens])
+  def decode(self, toks):
+     return self.model.decode([t for t in toks if t < self.num_base_tokens])
   def encode(self, text, allow_special=False):
     return self.model.encode(text, allowed_special="all" if allow_special else set(), disallowed_special=set())
 
@@ -77,11 +78,11 @@ def load(fn:str):
   else:
     return torch_load(fn)
 
-def build_transformer(model_path: Path, model_size="8B", quantize=None, device=None):
+def build_transformer(model_path: Path, shard: Shard, model_size="8B", quantize=None, device=None):
   # build model
   linear = nn.Linear
   with Context(THREEFRY=0):
-    model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True)
+    model = Transformer(**MODEL_PARAMS[model_size]["args"], shard=shard, linear=linear, max_context=8192, jit=False)
 
   # load weights
   if model_path.is_dir():
@@ -91,7 +92,7 @@ def build_transformer(model_path: Path, model_size="8B", quantize=None, device=None):
   else:
     weights = load(str(model_path))
   if "model.embed_tokens.weight" in weights:
-    weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
+    weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"], shard=shard)
   weights = fix_bf16(weights)
 
   with Context(BEAM=0):
@@ -117,7 +118,7 @@ def build_transformer(model_path: Path, model_size="8B", quantize=None, device=None):
   return model
 
 # default settings
-TEMPERATURE = 0.85
+TEMPERATURE = 0 # 0.85
 TOP_K = 25
 TOP_P = 0.9
 ALPHA_F = 0.1
@@ -154,14 +155,12 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
             return encode_role(role) + self.tokenizer.encode(content.strip()) + [self.tokenizer.special_tokens["<|eot_id|>"]]
 
         await self.ensure_shard(shard)
-        print([self.tokenizer.encode(prompt)])
 
         toks = [self.tokenizer.bos_id] + encode_message("user", prompt) + encode_role("assistant")
         start_pos = prefill(self.model, toks[:-1])
         last_tok = toks[-1]
 
-        output_data = np.array(self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist())
-        print(f"{output_data.size=}")
+        output_data = np.array([self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
         if output_data.size == 1:
            start_pos += 1
 
@@ -171,8 +170,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
         await self.ensure_shard(shard)
 
         start_pos = json.loads(inference_state)["start_pos"] if inference_state else 0
-        output_data: np.ndarray = np.array(self.model(Tensor([input_data]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist())
-        print(f"{output_data.size=}")
+        output_data: np.ndarray = np.array([self.model(Tensor([input_data]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
         if output_data.size == 1:
            start_pos += 1
 
@@ -181,7 +179,6 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     async def reset_shard(self, shard: Shard):
         await self.ensure_shard(shard)
 
-        print(f"Resetting shard: {shard}")
         self.model.reset()
 
     async def ensure_shard(self, shard: Shard):
@@ -190,10 +187,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
 
         model_path = Path(shard.model_id)
         size = "8B" # one of 8B or 70B for now
-        model = build_transformer(model_path, model_size=size)
+        model = build_transformer(model_path, shard=shard, model_size=size)
         tokenizer = Tokenizer(str((model_path if model_path.is_dir() else model_path.parent) / "tokenizer.model"))
 
         self.shard = shard
         self.model = model
         self.tokenizer = tokenizer
-
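
Unlike MLX, the tinygrad engine has real cross-device state: the decoding position. infer_tensor rebuilds start_pos from the incoming inference_state via json.loads, and the outgoing state is presumably packed the same way (the return lines fall outside these hunks, so treat the json.dumps side as an assumption):

    import json

    # Sending side (assumed): serialize the decoding position into the opaque state.
    start_pos = 12  # e.g. the position reached after prefill plus one decode step
    inference_state = json.dumps({"start_pos": start_pos})

    # Receiving side, as in infer_tensor above: default to 0 when no state arrives.
    start_pos = json.loads(inference_state)["start_pos"] if inference_state else 0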

+ 25 - 13
exo/inference/tinygrad/models/llama.py

@@ -1,6 +1,7 @@
 from typing import Tuple, Union, Optional, Dict, Any
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
+from exo.inference.shard import Shard
 
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
@@ -144,42 +145,47 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   return output_token
 
 class Transformer:
-  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
-    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
+  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, shard: Shard, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
+    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(shard.end_layer - shard.start_layer + 1)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.output = nn.Linear(dim, vocab_size, bias=False)
     self.max_context = max_context
     self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta)
     self.forward_jit = TinyJit(self.forward) if jit else None
+    self.shard = shard
 
-  def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
-    _bsz, seqlen = tokens.shape
+  def forward(self, h:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
+    seqlen = h.shape[1]
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
 
-    h = self.tok_embeddings(tokens)
+    if self.shard.is_first_layer():
+      h = self.tok_embeddings(h)
     mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).realize() if seqlen > 1 else None
+
     for i, layer in enumerate(self.layers):
       h = layer(h, start_pos, freqs_cis, mask)
-      print(f"layer {i}", h.tolist().__str__()[0:100])
-    logits = self.output(self.norm(h)).float()[:, -1, :]
 
-    return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
+    if self.shard.is_last_layer():
+        logits = self.output(self.norm(h)).float()[:, -1, :]
+        return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
+    else:
+      return h.realize()
 
   def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
     # TODO: better way to handle the first call v.s. the rest?
-    if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
-      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
+    # if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
+    #   return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
     return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
 
   def reset(self):
     for layer in self.layers:
-      print(f"reset layer: {layer.attention.cache_kv}")
-      layer.attention.cache_kv = layer.attention.cache_kv.zeros_like()
+      if hasattr(layer.attention, "cache_kv"):
+        layer.attention.cache_kv = layer.attention.cache_kv.zeros_like()
 
 # *** helpers ***
 
-def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
+def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, shard: Shard):
   def permute(v: Tensor, n_heads: int):
     return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
 
@@ -197,6 +203,12 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
     if ".rotary_emb." in k: continue
     v = v.to(Device.DEFAULT)
     if "model.layers" in k:
+      layer_num = int(k.split('.')[2])
+      if shard.start_layer <= layer_num <= shard.end_layer:
+          k = f"model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[3:])
+      else:
+        continue
+
       if "q_proj" in k:
         v = permute(v, n_heads)
       elif "k_proj" in k:

+ 0 - 1
main.py

@@ -8,7 +8,6 @@ from typing import List
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-from exo.inference.shard import Shard
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI