Преглед изворни кода

Initialize inference engine session in base class

Nel Nibcord пре 8 месеци
родитељ
комит
98edb393b2

+ 2 - 0
exo/inference/inference_engine.py

@@ -8,6 +8,8 @@ from .shard import Shard
 
 
 class InferenceEngine(ABC):
+  session = {}
+
   @abstractmethod
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     pass

+ 0 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -42,7 +42,6 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
     self.caches = OrderedDict()
-    self.session = {}
 
   async def poll_state(self, request_id: str, max_caches=2):
     if request_id in self.caches:

+ 0 - 1
exo/inference/tinygrad/inference.py

@@ -65,7 +65,6 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
     self.states = OrderedDict()
-    self.session = {}
 
   def poll_state(self, x, request_id: str, max_states=2):
     if request_id not in self.states: