Browse Source

Sampling needs to return integer tokens. Also, for some reason this errors if we have non-integral tokens here, though it isn't a problem elsewhere?

Nel Nibcord 8 months ago
parent
commit
03924cf9af

+ 1 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -40,7 +40,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
   async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
     y = mx.array(x)
     logits = y[:, -1, :]
-    out = np.array(sample_logits(logits, temp=temp, top_p=top_p))
+    out = np.array(sample_logits(logits, temp=temp, top_p=top_p), dtype=int)
     return out
 
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:

+ 4 - 2
exo/inference/test_inference_engine.py

@@ -12,10 +12,11 @@ import numpy as np
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
   resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
+  token_full = await inference_engine_1.sample(resp_full)
   next_resp_full = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
-    input_data=resp_full,
+    input_data=token_full,
   )
 
   pp = n_layers // 2
@@ -25,10 +26,11 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp1,
   )
+  tokens2 = await inference_engine_1.sample(resp2)
   resp3 = await inference_engine_1.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
-    input_data=resp2,
+    input_data=tokens2,
   )
   resp4 = await inference_engine_2.infer_tensor(
     "B",

+ 1 - 1
exo/inference/tinygrad/inference.py

@@ -70,7 +70,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     def sample_wrapper():
       return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize()
     out = await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
-    return out.numpy()
+    return out.numpy().astype(int)
 
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)