فهرست منبع

Initialize inference engine session in base class

Nel Nibcord 8 ماه پیش
والد
کامیت
98edb393b2
3 فایل تغییر یافته به همراه 2 افزوده شده و 2 حذف شده
  1. 2 0
      exo/inference/inference_engine.py
  2. 0 1
      exo/inference/mlx/sharded_inference_engine.py
  3. 0 1
      exo/inference/tinygrad/inference.py

+ 2 - 0
exo/inference/inference_engine.py

@@ -8,6 +8,8 @@ from .shard import Shard
 
 
 class InferenceEngine(ABC):
+  session = {}
+
   @abstractmethod
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     pass

+ 0 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -42,7 +42,6 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
     self.caches = OrderedDict()
-    self.session = {}
 
   async def poll_state(self, request_id: str, max_caches=2):
     if request_id in self.caches:

+ 0 - 1
exo/inference/tinygrad/inference.py

@@ -65,7 +65,6 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
     self.states = OrderedDict()
-    self.session = {}
 
   def poll_state(self, x, request_id: str, max_states=2):
     if request_id not in self.states: