|
@@ -4,7 +4,7 @@ import os
|
|
|
from exo.inference.tinygrad.models.llama import Transformer, TransformerShard, convert_from_huggingface, fix_bf16, sample_logits
|
|
|
from exo.inference.shard import Shard
|
|
|
from exo.inference.tokenizers import resolve_tokenizer
|
|
|
-from tinygrad.nn.state import load_state_dict
|
|
|
+from tinygrad.nn.state import safe_save, safe_load, get_state_dict, load_state_dict
|
|
|
from tinygrad import Tensor, nn, Context, TinyJit
|
|
|
from exo.inference.inference_engine import InferenceEngine
|
|
|
import numpy as np
|
|
@@ -96,9 +96,13 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
|
|
|
async def load_checkpoint(self, shard: Shard, path: str):
|
|
|
await self.ensure_shard(shard)
|
|
|
+ state_dict = safe_load(path)
|
|
|
+ await asyncio.get_running_loop().run_in_executor(self.executor, load_state_dict, self.model, state_dict)
|
|
|
|
|
|
async def save_checkpoint(self, shard: Shard, path: str):
|
|
|
await self.ensure_shard(shard)
|
|
|
+ state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
|
|
|
+ safe_save(state_dict, path)
|
|
|
|
|
|
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
|
|
await self.ensure_shard(shard)
|