inference_engine.py 632 B

12345678910111213141516171819
  1. import numpy as np
  2. from typing import Tuple, Optional, Callable
  3. from abc import ABC, abstractmethod
  4. from .shard import Shard
  5. class InferenceEngine(ABC):
  6. @abstractmethod
  7. async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
  8. pass
  9. @abstractmethod
  10. async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
  11. pass
  12. @abstractmethod
  13. def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
  14. pass