Преглед на файлове

Abstract load checkpoint method

Nel Nibcord преди 8 месеца
родител
ревизия
6aaea8c74c
променени са 3 файла, в които са добавени 10 реда и са изтрити 0 реда
  1. + 3 - 0
      exo/inference/dummy_inference_engine.py
  2. + 4 - 0
      exo/inference/inference_engine.py
  3. + 3 - 0
      exo/inference/tinygrad/inference.py

+ 3 - 0
exo/inference/dummy_inference_engine.py

@@ -32,3 +32,6 @@ class DummyInferenceEngine(InferenceEngine):
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard: return
     self.shard = shard
+  
+  async def load_checkpoint(self, shard: Shard, path: str):
+    await self.ensure_shard(shard)

+ 4 - 0
exo/inference/inference_engine.py

@@ -25,6 +25,10 @@ class InferenceEngine(ABC):
   @abstractmethod
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     pass
+
+  @abstractmethod
+  async def load_checkpoint(self, shard: Shard, path: str):
+    pass
   
   async def save_session(self, key, value):
     self.session[key] = value

+ 3 - 0
exo/inference/tinygrad/inference.py

@@ -92,6 +92,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
     return tokens
   
+  async def load_checkpoint(self, shard: Shard, path: str):
+    await self.ensure_shard(shard)
+  
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
     def wrap_infer():