inference_engine.py 418 B

1234567891011121314151617
  1. import numpy as np
  2. import mlx.nn as nn
  3. from abc import ABC, abstractmethod
  4. from .shard import Shard
  5. class InferenceEngine(ABC):
  6. @abstractmethod
  7. async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
  8. pass
  9. async def infer_prompt(self, shard: Shard, prompt: str) -> np.ndarray:
  10. pass
  11. @abstractmethod
  12. async def reset_shard(self, shard: Shard):
  13. pass