|
@@ -95,6 +95,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
async def load_checkpoint(self, shard: Shard, path: str):
|
|
|
await self.ensure_shard(shard)
|
|
|
|
|
|
+ async def save_checkpoint(self, shard: Shard, path: str):
|
|
|
+ await self.ensure_shard(shard)
|
|
|
+
|
|
|
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
|
|
await self.ensure_shard(shard)
|
|
|
def wrap_infer():
|