浏览代码

Abstract load checkpoint method

Nel Nibcord 5 月之前
父节点
当前提交
e9971f74ae
共有 3 个文件被更改,包括 10 次插入0 次删除
  1. 3 0
      exo/inference/dummy_inference_engine.py
  2. 4 0
      exo/inference/inference_engine.py
  3. 3 0
      exo/inference/tinygrad/inference.py

+ 3 - 0
exo/inference/dummy_inference_engine.py

@@ -32,3 +32,6 @@ class DummyInferenceEngine(InferenceEngine):
   async def ensure_shard(self, shard: Shard):
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard: return
     if self.shard == shard: return
     self.shard = shard
     self.shard = shard
+  
+  async def load_checkpoint(self, shard: Shard, path: str):
+    await self.ensure_shard(shard)

+ 4 - 0
exo/inference/inference_engine.py

@@ -25,6 +25,10 @@ class InferenceEngine(ABC):
   @abstractmethod
   @abstractmethod
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     pass
     pass
+
+  @abstractmethod
+  async def load_checkpoint(self, shard: Shard, path: str):
+    pass
   
   
   async def save_session(self, key, value):
   async def save_session(self, key, value):
     self.session[key] = value
     self.session[key] = value

+ 3 - 0
exo/inference/tinygrad/inference.py

@@ -92,6 +92,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
     return tokens
     return tokens
   
   
+  async def load_checkpoint(self, shard: Shard, path: str):
+    await self.ensure_shard(shard)
+  
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
     await self.ensure_shard(shard)
     def wrap_infer():
     def wrap_infer():