Removed tinygrad StatefulModel class, as it's no longer used

Nel Nibcord 8 months ago
parent commit b7bbda3348

+1 −1
exo/inference/tinygrad/inference.py

@@ -11,7 +11,7 @@ import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
 from concurrent.futures import ThreadPoolExecutor
-from .stateful_model import StatefulModel, make_prompt_state
+from .stateful_model import make_prompt_state
 from .losses import length_masked_ce_loss
 from collections import OrderedDict
 import asyncio
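
With the class gone, inference.py keeps only the make_prompt_state import. A hedged sketch of the caller-side pattern this implies (an assumption, not shown in this diff): the call site builds the per-request state itself and advances its start offset after each forward pass. forward_with_state and the states dict are illustrative names, not from the repo.

from .stateful_model import make_prompt_state

def forward_with_state(model, x, states: dict, request_id: str):
  # Build the KV cache on first sight of a request; reuse it afterwards.
  if request_id not in states:
    states[request_id] = make_prompt_state(x, model, model.shard)
  state = states[request_id]
  h = model.embed(x)
  out = model.forward(h, state.start, cache=state.cache)
  state.start += h.shape[1]  # advance past the tokens just consumed
  return out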

+0 −31
exo/inference/tinygrad/stateful_model.py

@@ -20,34 +20,3 @@ def make_prompt_state(x, model, shard):
   cache = [create_kv_cache(x, model.layers[i].attention.max_context, model.layers[i].attention.n_kv_heads, model.layers[i].attention.head_dim) for i in range(shard.start_layer, shard.end_layer + 1)]
 
   return ModelState(cache)
-
-  
-
-class StatefulModel:
-  def __init__(self, model, max_states: int = 2):
-    super().__init__()
-    self.model = model
-    self.max_states = max_states
-    self.states = OrderedDict()
- 
-  def init_cache(self, x: Tensor, request_id: str):
-    if len(self.states) >= self.max_states:
-      self.states.popitem(last=False)
-
-    self.states[request_id] = make_prompt_state(self.model, self.model.shard)
-
-  def __call__(self, x: Tensor, request_id: Optional[str] = None, use_cache: bool = True): 
-    h = self.model.embed(x)
-    #print(f"StatefulModel in <- {h}")
-    if use_cache and request_id is not None:
-      if request_id not in self.states:
-        self.init_cache(h, request_id)
-      else:
-        self.states.move_to_end(request_id)
-      out = self.model.forward(h, self.states[request_id].start, cache=self.states[request_id].cache)
-      self.states[request_id].start += h.shape[1]
-    else:
-      out = self.model.forward(h, 0)
-    #print(f"StatefulModel out -> {out}")
-    return out
-
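
For reference, the deleted class was a small LRU-bounded registry of per-request states: at most max_states entries, with the least recently used one evicted on overflow. (Note that its init_cache called make_prompt_state(self.model, self.model.shard), dropping the x argument the function's signature requires, one sign the class was dead code.) Below is a minimal, self-contained sketch of just that eviction pattern; LRUStates and make_state are hypothetical stand-ins, not names from the repo.

from collections import OrderedDict
from typing import Callable

class LRUStates:
  # Bookkeeping analogous to the removed StatefulModel: make_state plays
  # the role of make_prompt_state and returns a fresh per-request state.
  def __init__(self, make_state: Callable[[], object], max_states: int = 2):
    self.make_state = make_state
    self.max_states = max_states
    self.states: "OrderedDict[str, object]" = OrderedDict()

  def get(self, request_id: str):
    if request_id in self.states:
      self.states.move_to_end(request_id)  # mark as most recently used
    else:
      if len(self.states) >= self.max_states:
        self.states.popitem(last=False)  # evict the least recently used
      self.states[request_id] = self.make_state()
    return self.states[request_id]

# Usage: with two slots, "a" is evicted once "b" and "c" have been seen.
lru = LRUStates(lambda: {"start": 0, "cache": []})
lru.get("a"); lru.get("b"); lru.get("c")
assert list(lru.states) == ["b", "c"]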