# inference_engine.py

from abc import ABC, abstractmethod

import mlx.nn as nn
import numpy as np

from .shard import Shard
  5. class InferenceEngine(ABC):
  6. @abstractmethod
  7. async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
  8. pass
  9. @abstractmethod
  10. async def infer_prompt(self, shard: Shard, prompt: str) -> np.ndarray:
  11. pass
  12. @abstractmethod
  13. async def reset_shard(self, shard: Shard):
  14. pass