1
0
Эх сурвалжийг харах

Dummied up an abstact save_checkpoint

Nel Nibcord 4 сар өмнө
parent
commit
bd3114457f

+ 3 - 0
exo/inference/inference_engine.py

@@ -29,6 +29,9 @@ class InferenceEngine(ABC):
   @abstractmethod
   async def load_checkpoint(self, shard: Shard, path: str):
     pass
+
+  async def save_checkpoint(self, shard: Shard, path: str):
+    pass
   
   async def save_session(self, key, value):
     self.session[key] = value

+ 3 - 0
exo/inference/tinygrad/inference.py

@@ -95,6 +95,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   async def load_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
   
+  async def save_checkpoint(self, shard: Shard, path: str):
+    await self.ensure_shard(shard)
+  
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
     def wrap_infer():