inference_engine.py 490 B

12345678910111213141516171819
  1. import numpy as np
  2. import mlx.nn as nn
  3. from typing import Tuple
  4. from abc import ABC, abstractmethod
  5. from .shard import Shard
  class InferenceEngine(ABC):
      """Abstract base class declaring the asynchronous inference-engine interface.

      The class supplies no behaviour of its own: every method is an abstract
      coroutine, so a concrete subclass must override all three before it can
      be instantiated.
      """

      @abstractmethod
      async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> Tuple[np.ndarray, bool]:
          """Run inference for ``shard`` on an array input.

          Args:
              shard: The shard the call applies to.
              input_data: Input array handed to the engine.

          Returns:
              An ``(array, flag)`` pair.
              NOTE(review): the meaning of the boolean is not defined in this
              interface -- presumably a "finished" indicator; confirm against
              the concrete engines.
          """

      @abstractmethod
      async def infer_prompt(self, shard: Shard, prompt: str) -> Tuple[np.ndarray, bool]:
          """Run inference for ``shard`` on a text prompt.

          Args:
              shard: The shard the call applies to.
              prompt: Prompt text handed to the engine.

          Returns:
              An ``(array, flag)`` pair, the same shape of result as
              ``infer_tensor``.
              NOTE(review): the meaning of the boolean is not defined in this
              interface -- confirm against the concrete engines.
          """

      @abstractmethod
      async def reset_shard(self, shard: Shard) -> None:
          """Reset the engine for ``shard``.

          NOTE(review): what "reset" covers is left entirely to subclasses;
          nothing in this interface specifies it.

          Args:
              shard: The shard to reset.
          """