Quellcode durchsuchen

Model loading and saving for tinygrad

Nel Nibcord vor 4 Monaten
Ursprung
Commit
329efb2381
1 geänderte Dateien mit 5 neuen und 1 gelöschten Zeilen
  1. 5 1
      exo/inference/tinygrad/inference.py

+ 5 - 1
exo/inference/tinygrad/inference.py

@@ -4,7 +4,7 @@ import os
 from exo.inference.tinygrad.models.llama import Transformer, TransformerShard, convert_from_huggingface, fix_bf16, sample_logits
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
-from tinygrad.nn.state import load_state_dict
+from tinygrad.nn.state import safe_save, safe_load, get_state_dict, load_state_dict
 from tinygrad import Tensor, nn, Context, TinyJit
 from exo.inference.inference_engine import InferenceEngine
 import numpy as np
@@ -96,9 +96,13 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   
   async def load_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
+    state_dict = safe_load(path)
+    await asyncio.get_running_loop().run_in_executor(self.executor, load_state_dict, self.model, state_dict)
   
   async def save_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
+    state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
+    safe_save(state_dict, path) 
   
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)