Prechádzať zdrojové kódy

Add abstract set_on_download_progress to InferenceEngine; implement it as a no-op (pass) in the tinygrad engine

Alex Cheema pred 1 rokom
rodič
commit
1d54f10514

+ 5 - 1
exo/inference/inference_engine.py

@@ -1,6 +1,6 @@
 import numpy as np
 
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Callable
 from abc import ABC, abstractmethod
 from .shard import Shard
 
@@ -13,3 +13,7 @@ class InferenceEngine(ABC):
   @abstractmethod
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     pass
+
+  @abstractmethod
+  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
+    pass

+ 4 - 1
exo/inference/tinygrad/inference.py

@@ -1,7 +1,7 @@
 import asyncio
 from functools import partial
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Callable
 import json
 import tiktoken
 from tiktoken.load import load_tiktoken_bpe
@@ -294,3 +294,6 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.shard = shard
     self.model = model
     self.tokenizer = tokenizer
+
+  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
+    pass