Browse Source

Abstract load checkpoint method

Nel Nibcord 5 months ago
parent
commit
e9971f74ae

+ 3 - 0
exo/inference/dummy_inference_engine.py

@@ -32,3 +32,6 @@ class DummyInferenceEngine(InferenceEngine):
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard: return
     self.shard = shard
+  
+  async def load_checkpoint(self, shard: Shard, path: str):
+    await self.ensure_shard(shard)

+ 4 - 0
exo/inference/inference_engine.py

@@ -25,6 +25,10 @@ class InferenceEngine(ABC):
   @abstractmethod
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     pass
+
+  @abstractmethod
+  async def load_checkpoint(self, shard: Shard, path: str):
+    pass
   
   async def save_session(self, key, value):
     self.session[key] = value

+ 3 - 0
exo/inference/tinygrad/inference.py

@@ -92,6 +92,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
     return tokens
   
+  async def load_checkpoint(self, shard: Shard, path: str):
+    await self.ensure_shard(shard)
+  
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
     def wrap_infer():