Browse Source

Made temperature and top_p available to the inference engine sample interfaces

Nel Nibcord, 8 months ago
parent
commit
52ef6ee4a3

+ 2 - 2
exo/inference/mlx/sharded_inference_engine.py

@@ -37,10 +37,10 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
 
-  async def sample(self, x) -> np.ndarray:
+  async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
     y = mx.array(x)
     logits = y[:, -1, :]
-    out = np.array(sample_logits(logits))
+    out = np.array(sample_logits(logits, temp=temp, top_p=top_p))
     return out
 
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:

+ 2 - 2
exo/inference/tinygrad/inference.py

@@ -64,10 +64,10 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
 
-  async def sample(self, x: np.ndarray) -> np.ndarray:
+  async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
     logits = x[:, -1, :]
     def sample_wrapper():
-      return sample_logits(Tensor(logits).flatten(), TEMPERATURE, 0, 0.8, 0.0, 0.0).realize()
+      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize()
     out = await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
     return out.numpy()