add an opaque inference_state that inference engines can use to pass around small state to other devices

Alex Cheema · 11 months ago · commit dd8d18128c
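
This commit widens the InferenceEngine contract: infer_prompt and infer_tensor now accept an optional inference_state string and return a three-tuple (output, inference_state, is_finished), where the middle element is an opaque string the orchestrator forwards untouched to whichever device runs the next shard. A minimal sketch of the resulting interface, inferred from the diffs below (the abstract base class itself is not part of this commit's hunks):

    from abc import ABC, abstractmethod
    from typing import Optional, Tuple
    import numpy as np
    from exo.inference.shard import Shard

    class InferenceEngine(ABC):
        # Assumed shape of the interface after this commit; engines that keep all
        # state locally may return "" for the opaque inference_state element.
        @abstractmethod
        async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
            ...

        @abstractmethod
        async def infer_tensor(self, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
            ...

        @abstractmethod
        async def reset_shard(self, shard: Shard) -> None:
            ...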

+ 1 - 1
exo/api/chatgpt_api.py

@@ -52,7 +52,7 @@ class ChatGPTAPI:
         except Exception as e:
             pass # TODO
             # return web.json_response({'detail': str(e)}, status=500)
-        
+
         # poll for the response. TODO: implement callback for specific request id
         timeout = 90
         start_time = time.time()

+ 8 - 8
exo/inference/mlx/sharded_inference_engine.py

@@ -13,20 +13,20 @@ class MLXFixedShardInferenceEngine(InferenceEngine):
         model_shard, self.tokenizer = load_shard(model_path, shard)
         self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
 
-    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, bool):
+    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
         if shard != self.shard:
             raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
 
         output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
         print(f"output_data size: {output_data.size}, output_data: {output_data}")
-        return output_data, output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+        return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
-    async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, bool):
+    async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, str, bool):
         if shard != self.shard:
             raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
 
         output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(input_data)))
-        return output_data, output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+        return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
     async def reset_shard(self, shard: Shard):
         if shard != self.shard:
@@ -39,15 +39,15 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     def __init__(self):
         self.shard = None
 
-    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, bool):
+    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
         await self.ensure_shard(shard)
         output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
-        return output_data, output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+        return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
-    async def infer_tensor(self, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, bool):
+    async def infer_tensor(self, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
         await self.ensure_shard(shard)
         output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(input_data)))
-        return output_data, output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+        return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
     async def reset_shard(self, shard: Shard):
         await self.ensure_shard(shard)
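
The MLX engines keep their KV cache on-device inside StatefulShardedModel, so they have nothing to ship to the next device and return "" for the new middle element; only the return shape changes. A hedged usage sketch of unpacking the widened tuple (the engine and shard objects are assumed to already exist):

    # Hedged sketch: the middle element is opaque to the caller and is simply
    # forwarded to the engine that runs the next shard; MLX engines return "".
    output, inference_state, is_finished = await engine.infer_prompt(shard=first_shard, prompt="Hello")
    output, inference_state, is_finished = await engine.infer_tensor(shard=next_shard, input_data=output, inference_state=inference_state)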

+ 11 - 17
exo/inference/test_inference_engine.py

@@ -5,32 +5,26 @@ from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 import numpy as np
 
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
-async def test_inference_engine(inference_engine: InferenceEngine, model_id: str, input_data: np.array):
-    # inference_engine.reset_shard(Shard("", 0,0,0))
+async def test_inference_engine(inference_engine: InferenceEngine, model_id: str):
     prompt = "In a single word only, what is the capital of Japan? "
-    resp_full, _, _ = await inference_engine.infer_prompt(shard=Shard(model_id=model_id, start_layer=0, end_layer=1, n_layers=2), prompt=prompt)
+    resp_full, inference_state_full, _ = await inference_engine.infer_prompt(shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
 
-    print("resp_full", resp_full)
-    print("decoded", inference_engine.tokenizer.decode(resp_full))
+    await inference_engine.reset_shard(shard=Shard(model_id=model_id, start_layer=0, end_layer=10, n_layers=32))
+    resp1, inference_state, _ = await inference_engine.infer_prompt(shard=Shard(model_id=model_id, start_layer=0, end_layer=10, n_layers=32), prompt=prompt)
 
-    # inference_engine.reset_shard(Shard("", 0,0,0))
+    await inference_engine.reset_shard(shard=Shard(model_id=model_id, start_layer=11, end_layer=31, n_layers=32))
+    resp2, _, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=11, end_layer=31, n_layers=32), input_data=resp1, inference_state=inference_state)
 
-    resp1, inference_state, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=0, end_layer=0, n_layers=2), input_data=input_data)
-    print(f"Intermediate {inference_state=}")
-    resp2, _, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=1, end_layer=1, n_layers=2), input_data=resp1, inference_state=inference_state)
-
-    # assert np.array_equal(resp_full, resp2)
+    assert np.array_equal(resp_full, resp2)
 
 import asyncio
 
-# asyncio.run(test_inference_engine(
-#     MLXDynamicShardInferenceEngine(),
-#     "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
-#     [1234]
-# ))
+asyncio.run(test_inference_engine(
+    MLXDynamicShardInferenceEngine(),
+    "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+))
 
 asyncio.run(test_inference_engine(
     TinygradDynamicShardInferenceEngine(),
     "/Users/alex/Library/Caches/tinygrad/downloads/llama3-8b-sfr",
-    [1234]
 ))
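
The rewritten test runs the prompt once through a single shard covering all 32 layers, then again split across two shards (layers 0-10 and 11-31), handing both the intermediate activations (resp1) and the opaque inference_state from the first shard to the second, and asserts the two results match. A hedged sketch of the same pattern generalized to any chain of contiguous shards (the helper name is illustrative, not part of this commit):

    # Illustrative helper: thread the opaque inference_state through a chain of
    # contiguous shards, prompting on the first and passing tensors afterwards.
    async def run_across_shards(engine, shards, prompt):
        output, state, _ = await engine.infer_prompt(shard=shards[0], prompt=prompt)
        for shard in shards[1:]:
            output, state, _ = await engine.infer_tensor(shard=shard, input_data=output, inference_state=state)
        return output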

+ 9 - 13
exo/inference/tinygrad/inference.py

@@ -52,7 +52,8 @@ class Tokenizer:
   @property
   def stop_tokens(self): return {self.special_tokens["<|end_of_text|>"], self.special_tokens["<|eot_id|>"]}
 
-  def decode(self, toks): return self.model.decode([t for t in toks if t < self.num_base_tokens])
+  def decode(self, toks):
+     return self.model.decode([t for t in toks if t < self.num_base_tokens])
   def encode(self, text, allow_special=False):
     return self.model.encode(text, allowed_special="all" if allow_special else set(), disallowed_special=set())
 
@@ -77,11 +78,11 @@ def load(fn:str):
   else:
     return torch_load(fn)
 
-def build_transformer(model_path: Path, model_size="8B", quantize=None, device=None):
+def build_transformer(model_path: Path, shard: Shard, model_size="8B", quantize=None, device=None):
   # build model
   linear = nn.Linear
   with Context(THREEFRY=0):
-    model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True)
+    model = Transformer(**MODEL_PARAMS[model_size]["args"], shard=shard, linear=linear, max_context=8192, jit=False)
 
   # load weights
   if model_path.is_dir():
@@ -91,7 +92,7 @@ def build_transformer(model_path: Path, model_size="8B", quantize=None, device=N
   else:
     weights = load(str(model_path))
   if "model.embed_tokens.weight" in weights:
-    weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
+    weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"], shard=shard)
   weights = fix_bf16(weights)
 
   with Context(BEAM=0):
@@ -117,7 +118,7 @@ def build_transformer(model_path: Path, model_size="8B", quantize=None, device=N
   return model
 
 # default settings
-TEMPERATURE = 0.85
+TEMPERATURE = 0 # 0.85
 TOP_K = 25
 TOP_P = 0.9
 ALPHA_F = 0.1
@@ -154,14 +155,12 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
             return encode_role(role) + self.tokenizer.encode(content.strip()) + [self.tokenizer.special_tokens["<|eot_id|>"]]
 
         await self.ensure_shard(shard)
-        print([self.tokenizer.encode(prompt)])
 
         toks = [self.tokenizer.bos_id] + encode_message("user", prompt) + encode_role("assistant")
         start_pos = prefill(self.model, toks[:-1])
         last_tok = toks[-1]
 
-        output_data = np.array(self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist())
-        print(f"{output_data.size=}")
+        output_data = np.array([self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
        if output_data.size == 1:
           start_pos += 1
 
@@ -171,8 +170,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
        await self.ensure_shard(shard)
 
        start_pos = json.loads(inference_state)["start_pos"] if inference_state else 0
-        output_data: np.ndarray = np.array(self.model(Tensor([input_data]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist())
-        print(f"{output_data.size=}")
+        output_data: np.ndarray = np.array([self.model(Tensor([input_data]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
        if output_data.size == 1:
           start_pos += 1
 
@@ -181,7 +179,6 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     async def reset_shard(self, shard: Shard):
        await self.ensure_shard(shard)
 
-        print(f"Resetting shard: {shard}")
        self.model.reset()
 
     async def ensure_shard(self, shard: Shard):
@@ -190,10 +187,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
 
        model_path = Path(shard.model_id)
        size = "8B" # one of 8B or 70B for now
-        model = build_transformer(model_path, model_size=size)
+        model = build_transformer(model_path, shard=shard, model_size=size)
        tokenizer = Tokenizer(str((model_path if model_path.is_dir() else model_path.parent) / "tokenizer.model"))
 
        self.shard = shard
        self.model = model
        self.tokenizer = tokenizer
-
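
In the tinygrad engine the opaque state carries the KV-cache position: infer_tensor recovers start_pos via json.loads(inference_state)["start_pos"], and the corresponding return path (outside these hunks) presumably serializes it back to JSON. A minimal hedged sketch of that round trip:

    import json
    from typing import Optional

    # Hedged sketch of the start_pos round trip implied by the json.loads call above;
    # the exact serialization on the return path is not shown in these hunks.
    def pack_state(start_pos: int) -> str:
        return json.dumps({"start_pos": start_pos})

    def unpack_state(inference_state: Optional[str]) -> int:
        return json.loads(inference_state)["start_pos"] if inference_state else 0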

+ 25 - 13
exo/inference/tinygrad/models/llama.py

@@ -1,6 +1,7 @@
 from typing import Tuple, Union, Optional, Dict, Any
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
+from exo.inference.shard import Shard
 
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
@@ -144,42 +145,47 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   return output_token
 
 class Transformer:
-  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
-    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
+  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, shard: Shard, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
+    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(shard.end_layer - shard.start_layer + 1)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.output = nn.Linear(dim, vocab_size, bias=False)
     self.max_context = max_context
     self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta)
     self.forward_jit = TinyJit(self.forward) if jit else None
+    self.shard = shard
 
-  def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
-    _bsz, seqlen = tokens.shape
+  def forward(self, h:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
+    seqlen = h.shape[1]
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
 
-    h = self.tok_embeddings(tokens)
+    if self.shard.is_first_layer():
+      h = self.tok_embeddings(h)
     mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).realize() if seqlen > 1 else None
+
     for i, layer in enumerate(self.layers):
       h = layer(h, start_pos, freqs_cis, mask)
-      print(f"layer {i}", h.tolist().__str__()[0:100])
-    logits = self.output(self.norm(h)).float()[:, -1, :]
 
-    return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
+    if self.shard.is_last_layer():
+        logits = self.output(self.norm(h)).float()[:, -1, :]
+        return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
+    else:
+      return h.realize()
 
   def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
     # TODO: better way to handle the first call v.s. the rest?
-    if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
-      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
+    # if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
+    #   return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
     return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
 
   def reset(self):
     for layer in self.layers:
-      print(f"reset layer: {layer.attention.cache_kv}")
-      layer.attention.cache_kv = layer.attention.cache_kv.zeros_like()
+      if hasattr(layer.attention, "cache_kv"):
+        layer.attention.cache_kv = layer.attention.cache_kv.zeros_like()
 
 # *** helpers ***
 
-def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
+def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, shard: Shard):
   def permute(v: Tensor, n_heads: int):
     return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
 
@@ -197,6 +203,12 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he
     if ".rotary_emb." in k: continue
     v = v.to(Device.DEFAULT)
     if "model.layers" in k:
+      layer_num = int(k.split('.')[2])
+      if shard.start_layer <= layer_num <= shard.end_layer:
+          k = f"model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[3:])
+      else:
+        continue
+
       if "q_proj" in k:
         v = permute(v, n_heads)
       elif "k_proj" in k:
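
The sharded Transformer only instantiates the layers in [shard.start_layer, shard.end_layer], applies token embeddings on the first shard and the output head on the last, and convert_from_huggingface remaps each global HuggingFace layer index to a shard-local one (layer_num - shard.start_layer), skipping weights that fall outside the shard. A hedged sketch of what exo.inference.shard.Shard is assumed to look like, based only on the attributes and helpers used in this diff:

    from dataclasses import dataclass

    # Assumed shape of exo.inference.shard.Shard, inferred from its usage here
    # (model_id, start_layer, end_layer, n_layers, is_first_layer, is_last_layer);
    # the real definition may differ.
    @dataclass
    class Shard:
        model_id: str
        start_layer: int  # first global layer index handled by this shard (inclusive)
        end_layer: int    # last global layer index handled by this shard (inclusive)
        n_layers: int     # total number of layers in the full model

        def is_first_layer(self) -> bool:
            return self.start_layer == 0

        def is_last_layer(self) -> bool:
            return self.end_layer == self.n_layers - 1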

+ 0 - 1
main.py

@@ -8,7 +8,6 @@ from typing import List
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-from exo.inference.shard import Shard
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI