@@ -12,6 +12,7 @@ class InferenceEngine(ABC):
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
pass
+ @abstractmethod
async def sample(self, x: np.ndarray) -> np.ndarray: