|
@@ -1,6 +1,6 @@
|
|
|
import numpy as np
|
|
|
|
|
|
-from typing import Tuple, Optional
|
|
|
+from typing import Tuple, Optional, Callable
|
|
|
from abc import ABC, abstractmethod
|
|
|
from .shard import Shard
|
|
|
|
|
@@ -13,3 +13,7 @@ class InferenceEngine(ABC):
|
|
|
@abstractmethod
|
|
|
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
|
|
|
pass
|
|
|
+
|
|
|
+ @abstractmethod
|
|
|
+ def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
|
|
|
+ pass
|