Hoisted caching to a wrapper class

Nel Nibcord 5 months ago
commit 90518a3bbe
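
The diff below moves per-request KV-cache bookkeeping out of the model classes and into a StatefulModel wrapper. A minimal sketch of the resulting call shape, using only names that appear in this commit (the transformer instance, tokens, start_pos and request_id are assumed to come from the surrounding engine code):

  # Before: the Transformer tracked caches itself, keyed by request_id:
  #   out = transformer(tokens, start_pos, request_id)
  # After: StatefulModel owns the OrderedDict of caches and passes the
  # per-request cache into the now-stateless forward pass explicitly.
  model = StatefulModel(transformer)
  out = model(tokens, start_pos, request_id)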

+ 1 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -2,7 +2,7 @@ import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
 from ..inference_engine import InferenceEngine
-from .sharded_model import StatefulModel
+from .stateful_model import StatefulModel
 from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
 from typing import Dict, Optional, Tuple

+ 1 - 2
exo/inference/mlx/sharded_model.py → exo/inference/mlx/stateful_model.py

@@ -1,10 +1,9 @@
-from typing import Dict, Generator, Optional, Tuple
+from typing import Dict, Tuple
 from collections import OrderedDict
 
 import mlx.core as mx
 import mlx.nn as nn
 from mlx_lm.models.cache import make_prompt_cache
-from mlx_lm.sample_utils import top_p_sampling
 
 from ..shard import Shard
 

+ 1 - 1
exo/inference/mlx/test_sharded_llama.py

@@ -1,5 +1,5 @@
 import mlx.core as mx
-from exo.inference.mlx.sharded_model import StatefulModel
+from exo.inference.mlx.stateful_model import StatefulModel
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.shard import Shard
 

+ 1 - 1
exo/inference/mlx/test_sharded_llava.py

@@ -7,7 +7,7 @@ from io import BytesIO
 import mlx.core as mx
 from mlx_lm.models.cache import KVCache
 
-from exo.inference.mlx.sharded_model import StatefulShardedModel
+from exo.inference.mlx.stateful_model import StatefulModel
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.shard import Shard
 

+ 4 - 1
exo/inference/tinygrad/inference.py

@@ -12,6 +12,7 @@ import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
 from concurrent.futures import ThreadPoolExecutor
+from .stateful_model import StatefulModel
 import asyncio
 
 Tensor.no_grad = True
@@ -94,9 +95,11 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     model_path = await self.shard_downloader.ensure_shard(shard)
 
     if self.shard != shard:
+      loop = asyncio.get_running_loop()
       parameters = "1B" if "1b" in shard.model_id.lower() else "3B" if "3b" in shard.model_id.lower() else "8B" if "8b" in shard.model_id.lower() else "70B"
-      self.model = await asyncio.get_running_loop().run_in_executor(self.executor, build_transformer, model_path, shard, parameters)
+      model_shard = await loop.run_in_executor(self.executor, build_transformer, model_path, shard, parameters)
 
       tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
       self.tokenizer = await resolve_tokenizer(tokenizer_path)
       self.shard = shard
+      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 
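
Because self.model is now the StatefulModel wrapper rather than the bare transformer, downstream inference code (outside this hunk) is expected to call it with a request_id so the wrapper can resolve the right KV cache. A hedged sketch of such a call site; the method name and the toks value are assumptions, not part of this commit:

  from tinygrad import Tensor

  def _run_model(self, toks: list, start_pos: int, request_id: str):
    # StatefulModel.__call__ embeds the tokens, finds or creates this
    # request's KV cache, and runs the (possibly JIT-compiled) forward pass.
    return self.model(Tensor([toks]), start_pos, request_id)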

+ 14 - 29
exo/inference/tinygrad/models/llama.py

@@ -1,4 +1,4 @@
-from typing import Tuple, Union, Optional, Dict, Any
+from typing import Tuple, Union, Optional, Dict, Any, List
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
 from collections import OrderedDict
@@ -48,13 +48,6 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)
 
-def create_kv_cache(x: Tensor, max_context: int, n_kv_heads: int, head_dim: int):
-  cache_kv = Tensor.zeros(2, x.shape[0], max_context, n_kv_heads, head_dim, dtype=x.dtype).contiguous().realize()
-  if isinstance(x.device, tuple):
-    # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
-    cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
-  return cache_kv.realize()
-
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
@@ -194,7 +187,6 @@ class Transformer:
     feed_forward=FeedForward,
     rope_scaling: Optional[Dict[str, float]] = None,
     tie_word_embeddings=False,
-    max_caches=2,
   ):
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
@@ -206,17 +198,17 @@ class Transformer:
     self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta, rope_scaling=rope_scaling).contiguous()
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.shard = shard
-    self.caches = OrderedDict()
-    self.max_caches = max_caches
 
-  def forward(self, x: Tensor, start_pos: Union[Variable, int], request_id: str):
+  def forward(self, x: Tensor, start_pos: Union[Variable, int], cache: Optional[List[Tensor]] = None):
     seqlen = x.shape[1]
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
     mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
 
     h = x
 
-    for i, c in zip(range(self.shard.start_layer, self.shard.end_layer + 1), self.caches[request_id]):
+    if cache is None:
+      cache = [None for _ in range(self.shard.start_layer, self.shard.end_layer + 1)]
+    for i, c in zip(range(self.shard.start_layer, self.shard.end_layer + 1), cache):
       layer = self.layers[i]
       h = layer(h, start_pos, freqs_cis, mask, cache=c)
 
@@ -226,26 +218,19 @@ class Transformer:
     else:
       return h
 
-  def init_cache(self, x: Tensor, request_id: str):
-    cache = [create_kv_cache(x, self.layers[i].attention.max_context, self.layers[i].attention.n_kv_heads, self.layers[i].attention.head_dim) for i in range(self.shard.start_layer, self.shard.end_layer + 1)]
-    if len(self.caches) >= self.max_caches:
-      self.caches.popitem(last=False)
-
-    self.caches[request_id] = cache
-
-  def __call__(self, tokens: Tensor, start_pos: Variable, request_id: str):
+  def embed(self, inputs: Tensor):
     if self.shard.is_first_layer():
-      h = self.tok_embeddings(tokens)
+      h = self.tok_embeddings(inputs)
     else:
-      h = tokens
-    if request_id not in self.caches:
-      self.init_cache(h, request_id)
-    else:
-      self.caches.move_to_end(request_id)
+      h = inputs
+    return h
+
+  def __call__(self, tokens: Tensor, start_pos: Variable, request_id: str, cache: Optional[List[Tensor]] = None):
     # TODO: better way to handle the first call v.s. the rest?
+    h = self.embed(tokens)
     if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
-      return self.forward_jit(h, Variable("start_pos", 0, self.max_context).bind(start_pos), request_id)
-    return self.forward(h, start_pos, request_id)
+      return self.forward_jit(h, Variable("start_pos", 0, self.max_context).bind(start_pos), cache=cache)
+    return self.forward(h, start_pos, cache=cache)
 
 
 # *** helpers ***
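
With the cache hoisted out, Transformer.forward no longer knows about request ids; callers embed, build the per-layer caches, and pass them in. A minimal sketch under the assumption that model, tokens and start_pos come from the surrounding engine (create_kv_cache now lives in stateful_model.py, below):

  h = model.embed(tokens)  # tok_embeddings only on the first shard, otherwise pass-through
  cache = [create_kv_cache(h, l.attention.max_context, l.attention.n_kv_heads, l.attention.head_dim)
           for l in (model.layers[i] for i in range(model.shard.start_layer, model.shard.end_layer + 1))]
  out = model.forward(h, start_pos, cache=cache)  # cache=None runs the layers without a KV cache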

+ 35 - 0
exo/inference/tinygrad/stateful_model.py

@@ -0,0 +1,35 @@
+from tinygrad import Tensor, Variable
+from tinygrad.helpers import getenv
+from collections import OrderedDict
+
+def create_kv_cache(x: Tensor, max_context: int, n_kv_heads: int, head_dim: int):
+  cache_kv = Tensor.zeros(2, x.shape[0], max_context, n_kv_heads, head_dim, dtype=x.dtype).contiguous().realize()
+  if isinstance(x.device, tuple):
+    # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
+    cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
+  return cache_kv.realize()
+
+class StatefulModel:
+  def __init__(self, model, max_caches: int = 2):
+    super().__init__()
+    self.model = model
+    self.max_caches = max_caches
+    self.caches = OrderedDict()
+ 
+  def init_cache(self, x: Tensor, request_id: str):
+    cache = [create_kv_cache(x, self.model.layers[i].attention.max_context, self.model.layers[i].attention.n_kv_heads, self.model.layers[i].attention.head_dim) for i in range(self.model.shard.start_layer, self.model.shard.end_layer + 1)]
+    if len(self.caches) >= self.max_caches:
+      self.caches.popitem(last=False)
+
+    self.caches[request_id] = cache
+
+  def __call__(self, x: Tensor, start_pos: Variable, request_id: str): 
+    h = self.model.embed(x)
+    if request_id not in self.caches:
+      self.init_cache(h, request_id)
+    else:
+      self.caches.move_to_end(request_id)
+    if h.shape[0:2] == (1, 1) and self.model.forward_jit is not None:
+      return self.model.forward_jit(h, Variable("start_pos", 0, self.model.max_context).bind(start_pos), cache=self.caches[request_id])
+    return self.model.forward(h, start_pos, cache=self.caches[request_id])
+
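
A hedged usage sketch of the new wrapper, assuming a tinygrad model shard already built by build_transformer as in inference.py above; prompt_tokens and next_tok are placeholders, not names from this commit:

  from tinygrad import Tensor
  from exo.inference.tinygrad.stateful_model import StatefulModel

  model = StatefulModel(model_shard, max_caches=2)
  out0 = model(Tensor([prompt_tokens]), 0, "req-a")                # creates and stores a KV cache for "req-a"
  out1 = model(Tensor([[next_tok]]), len(prompt_tokens), "req-a")  # reuses (and refreshes) that cache
  # a call with a third distinct request_id would evict the least recently used cache entry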