Browse Source

Hoisted caching to a wrapper class

Nel Nibcord 5 tháng trước cách đây
mục cha
commit
90518a3bbe

+ 1 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -2,7 +2,7 @@ import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
 from ..inference_engine import InferenceEngine
-from .sharded_model import StatefulModel
+from .stateful_model import StatefulModel
 from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
 from typing import Dict, Optional, Tuple

+ 1 - 2
exo/inference/mlx/sharded_model.py → exo/inference/mlx/stateful_model.py

@@ -1,10 +1,9 @@
-from typing import Dict, Generator, Optional, Tuple
+from typing import Dict, Tuple
 from collections import OrderedDict
 
 import mlx.core as mx
 import mlx.nn as nn
 from mlx_lm.models.cache import make_prompt_cache
-from mlx_lm.sample_utils import top_p_sampling
 
 from ..shard import Shard
 

+ 1 - 1
exo/inference/mlx/test_sharded_llama.py

@@ -1,5 +1,5 @@
 import mlx.core as mx
-from exo.inference.mlx.sharded_model import StatefulModel
+from exo.inference.mlx.stateful_model import StatefulModel
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.shard import Shard
 

+ 1 - 1
exo/inference/mlx/test_sharded_llava.py

@@ -7,7 +7,7 @@ from io import BytesIO
 import mlx.core as mx
 from mlx_lm.models.cache import KVCache
 
-from exo.inference.mlx.sharded_model import StatefulShardedModel
+from exo.inference.mlx.stateful_model import StatefulModel
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.shard import Shard
 

+ 4 - 1
exo/inference/tinygrad/inference.py

@@ -12,6 +12,7 @@ import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
 from concurrent.futures import ThreadPoolExecutor
+from .stateful_model import StatefulModel
 import asyncio
 
 Tensor.no_grad = True
@@ -94,9 +95,11 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     model_path = await self.shard_downloader.ensure_shard(shard)
 
     if self.shard != shard:
+      loop = asyncio.get_running_loop()
       parameters = "1B" if "1b" in shard.model_id.lower() else "3B" if "3b" in shard.model_id.lower() else "8B" if "8b" in shard.model_id.lower() else "70B"
-      self.model = await asyncio.get_running_loop().run_in_executor(self.executor, build_transformer, model_path, shard, parameters)
+      model_shard = await loop.run_in_executor(self.executor, build_transformer, model_path, shard, parameters)
 
       tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
       self.tokenizer = await resolve_tokenizer(tokenizer_path)
       self.shard = shard
+      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 

+ 14 - 29
exo/inference/tinygrad/models/llama.py

@@ -1,4 +1,4 @@
-from typing import Tuple, Union, Optional, Dict, Any
+from typing import Tuple, Union, Optional, Dict, Any, List
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
 from collections import OrderedDict
@@ -48,13 +48,6 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)
 
-def create_kv_cache(x: Tensor, max_context: int, n_kv_heads: int, head_dim: int):
-  cache_kv = Tensor.zeros(2, x.shape[0], max_context, n_kv_heads, head_dim, dtype=x.dtype).contiguous().realize()
-  if isinstance(x.device, tuple):
-    # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
-    cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
-  return cache_kv.realize()
-
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
@@ -194,7 +187,6 @@ class Transformer:
     feed_forward=FeedForward,
     rope_scaling: Optional[Dict[str, float]] = None,
     tie_word_embeddings=False,
-    max_caches=2,
   ):
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
@@ -206,17 +198,17 @@ class Transformer:
     self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta, rope_scaling=rope_scaling).contiguous()
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.shard = shard
-    self.caches = OrderedDict()
-    self.max_caches = max_caches
 
-  def forward(self, x: Tensor, start_pos: Union[Variable, int], request_id: str):
+  def forward(self, x: Tensor, start_pos: Union[Variable, int], cache: Optional[List[Tensor]] = None):
     seqlen = x.shape[1]
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
     mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
 
     h = x
 
-    for i, c in zip(range(self.shard.start_layer, self.shard.end_layer + 1), self.caches[request_id]):
+    if cache is None:
+      cache = [None for _ in range(self.shard.start_layer, self.shard.end_layer + 1)]  
+    for i, c in zip(range(self.shard.start_layer, self.shard.end_layer + 1), cache):
       layer = self.layers[i]
       h = layer(h, start_pos, freqs_cis, mask, cache=c)
 
@@ -226,26 +218,19 @@ class Transformer:
     else:
       return h
 
-  def init_cache(self, x: Tensor, request_id: str):
-    cache = [create_kv_cache(x, self.layers[i].attention.max_context, self.layers[i].attention.n_kv_heads, self.layers[i].attention.head_dim) for i in range(self.shard.start_layer, self.shard.end_layer + 1)]
-    if len(self.caches) >= self.max_caches:
-      self.caches.popitem(last=False)
-
-    self.caches[request_id] = cache
-
-  def __call__(self, tokens: Tensor, start_pos: Variable, request_id: str):
+  def embed(self, inputs: Tensor):
     if self.shard.is_first_layer():
-      h = self.tok_embeddings(tokens)
+      h = self.tok_embeddings(inputs)
     else:
-      h = tokens
-    if request_id not in self.caches:
-      self.init_cache(h, request_id)
-    else:
-      self.caches.move_to_end(request_id)
+      h = inputs
+    return h
+
+  def __call__(self, tokens: Tensor, start_pos: Variable, request_id: str, cache: Optional[List[Tensor]] = None):
     # TODO: better way to handle the first call v.s. the rest?
+    h = self.embed(x)
     if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
-      return self.forward_jit(h, Variable("start_pos", 0, self.max_context).bind(start_pos), request_id)
-    return self.forward(h, start_pos, request_id)
+      return self.forward_jit(h, Variable("start_pos", 0, self.max_context).bind(start_pos), cache=cache)
+    return self.forward(h, start_pos, cache=cache)
 
 
 # *** helpers ***

+ 34 - 0
exo/inference/tinygrad/stateful_model.py

@@ -0,0 +1,34 @@
+from tinygrad import Tensor, Variable 
+from collections import OrderedDict
+
+def create_kv_cache(x: Tensor, max_context: int, n_kv_heads: int, head_dim: int):
+  cache_kv = Tensor.zeros(2, x.shape[0], max_context, n_kv_heads, head_dim, dtype=x.dtype).contiguous().realize()
+  if isinstance(x.device, tuple):
+    # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
+    cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
+  return cache_kv.realize()
+
+class StatefulModel:
+  def __init__(self, model, max_caches: int = 2):
+    super().__init__()
+    self.model = model
+    self.max_caches = max_caches
+    self.caches = OrderedDict()
+ 
+  def init_cache(self, x: Tensor, request_id: str):
+    cache = [create_kv_cache(x, self.model.layers[i].attention.max_context, self.model.layers[i].attention.n_kv_heads, self.model.layers[i].attention.head_dim) for i in range(self.model.shard.start_layer, self.model.shard.end_layer + 1)]
+    if len(self.caches) >= self.max_caches:
+      self.caches.popitem(last=False)
+
+    self.caches[request_id] = cache
+
+  def __call__(self, x: Tensor, start_pos: Variable, request_id: str): 
+    h = self.model.embed(x)
+    if request_id not in self.caches:
+      self.init_cache(h, request_id)
+    else:
+      self.caches.move_to_end(request_id)
+    if h.shape[0:2] == (1, 1) and self.model.forward_jit is not None:
+      return self.model.forward_jit(h, Variable("start_pos", 0, self.model.max_context).bind(start_pos), cache=self.caches[request_id])
+    return self.model.forward(h, start_pos, cache=self.caches[request_id])
+