
fix dummy test

Alex Cheema 5 months ago
parent
commit
ffe78f6d0b

+ 2 - 2
exo/inference/dummy_inference_engine.py

@@ -25,9 +25,9 @@ class DummyInferenceEngine(InferenceEngine):
   async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     return self.tokenizer.decode(tokens)
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: dict = {}) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: dict = {}) -> tuple[np.ndarray, Optional[dict]]:
     await self.ensure_shard(shard)
-    return input_data + 1 if self.shard.is_last_layer() else input_data
+    return input_data + 1 if self.shard.is_last_layer() else input_data, None
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard: return
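
With this change, infer_tensor returns an (output, inference_state) pair instead of a bare array, and the dummy engine returns None for the state. A minimal sketch of a caller unpacking the new return value, mirroring the updated tests further down (the shard values and request id are illustrative, not part of this commit):

  import asyncio
  import numpy as np
  from exo.inference.dummy_inference_engine import DummyInferenceEngine
  from exo.inference.shard import Shard

  async def main():
    # DummyInferenceEngine is now constructed without a shard downloader (see the test changes below).
    engine = DummyInferenceEngine()
    shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
    # infer_tensor now yields (output, inference_state); the dummy engine's state is None.
    output, state = await engine.infer_tensor("demo_request", shard, np.array([[1, 2, 3]]))
    print(output.shape, state)

  asyncio.run(main())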

+ 0 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -84,7 +84,6 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     x = mx.array(input_data)
     if self.model.model_type != 'StableDiffusionPipeline':
       output_data = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
-      inference_state = {}
     else:
       output_data, inference_state = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
     output_data = np.array(output_data)
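
Dropping the inference_state reset means the state passed into infer_tensor is no longer discarded after the non-StableDiffusion branch, so a caller can carry it across calls (the StableDiffusionPipeline branch returns an updated state). A hypothetical sketch of that pattern; the engine, shard, request id, step count, and initial input are illustrative, not part of this commit:

  # Hypothetical: carry inference_state forward across successive infer_tensor calls.
  state = {}
  data = initial_input  # np.ndarray produced by an earlier infer_prompt/infer_tensor call
  for _ in range(num_steps):
    data, state = await engine.infer_tensor("demo_request", shard, data, inference_state=state or {})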

+ 5 - 11
exo/inference/test_dummy_inference_engine.py

@@ -1,22 +1,16 @@
 import pytest
-import json
 import numpy as np
 from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.shard import Shard
 
 
-class MockShardDownloader:
-  async def ensure_shard(self, shard):
-    pass
-
-
 @pytest.mark.asyncio
 async def test_dummy_inference_specific():
-  engine = DummyInferenceEngine(MockShardDownloader())
+  engine = DummyInferenceEngine()
   test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
   test_prompt = "This is a test prompt"
 
-  result = await engine.infer_prompt("test_request", test_shard, test_prompt)
+  result, _ = await engine.infer_prompt("test_request", test_shard, test_prompt)
 
   print(f"Inference result shape: {result.shape}")
 
@@ -26,20 +20,20 @@ async def test_dummy_inference_specific():
 @pytest.mark.asyncio
 async def test_dummy_inference_engine():
   # Initialize the DummyInferenceEngine
-  engine = DummyInferenceEngine(MockShardDownloader())
+  engine = DummyInferenceEngine()
 
   # Create a test shard
   shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
 
   # Test infer_prompt
-  output = await engine.infer_prompt("test_id", shard, "Test prompt")
+  output, _ = await engine.infer_prompt("test_id", shard, "Test prompt")
 
   assert isinstance(output, np.ndarray), "Output should be a numpy array"
   assert output.ndim == 2, "Output should be 2-dimensional"
 
   # Test infer_tensor
   input_tensor = np.array([[1, 2, 3]])
-  output = await engine.infer_tensor("test_id", shard, input_tensor)
+  output, _ = await engine.infer_tensor("test_id", shard, input_tensor)
 
   assert isinstance(output, np.ndarray), "Output should be a numpy array"
   assert output.ndim == 2, "Output should be 2-dimensional"

+ 5 - 5
exo/inference/test_inference_engine.py

@@ -11,7 +11,7 @@ import numpy as np
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
-  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
+  resp_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
   token_full = await inference_engine_1.sample(resp_full)
   token_full = token_full.reshape(1, -1)
   next_resp_full = await inference_engine_1.infer_tensor(
@@ -21,20 +21,20 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   )
 
   pp = n_layers // 2
-  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
-  resp2 = await inference_engine_2.infer_tensor(
+  resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
+  resp2, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp1,
   )
   tokens2 = await inference_engine_1.sample(resp2)
   tokens2 = tokens2.reshape(1, -1)
-  resp3 = await inference_engine_1.infer_tensor(
+  resp3, _ = await inference_engine_1.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
     input_data=tokens2,
   )
-  resp4 = await inference_engine_2.infer_tensor(
+  resp4, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp3,
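
The same unpacking applies to infer_prompt. A condensed sketch of the prompt -> sample -> next-step flow these tests exercise, with the tuple returns unpacked (request id and prompt are illustrative):

  async def prompt_then_step(engine: InferenceEngine, model_id: str, n_layers: int):
    shard = Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers)
    # infer_prompt and infer_tensor both return (output, inference_state) now.
    logits, _ = await engine.infer_prompt("demo", shard=shard, prompt="Hello")
    tokens = (await engine.sample(logits)).reshape(1, -1)
    next_logits, _ = await engine.infer_tensor("demo", shard=shard, input_data=tokens)
    return next_logits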

+ 1 - 1
exo/orchestration/node.py

@@ -320,7 +320,7 @@ class Node:
           loss, grad = await self.inference_engine.train(request_id, shard, example, target, length)
         else:
           self.outstanding_requests[request_id] = "preprocessing"
-          step = await self.inference_engine.infer_tensor(request_id, shard, example)
+          step, _ = await self.inference_engine.infer_tensor(request_id, shard, example)
           self.outstanding_requests[request_id] = "waiting"
           loss, backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
           self.outstanding_requests[request_id] = "training"