Browse Source

Merge pull request #364 from rahat2134/DummyInferenceEngine

Implementation of DummyInferenceEngine
Alex Cheema 9 months ago
parent
commit
727d7fffaf

+ 33 - 3
.circleci/config.yml

@@ -44,6 +44,13 @@ commands:
             # Check processes before proceeding
             check_processes
 
+            # Special handling for dummy engine
+            if [ "<<parameters.inference_engine>>" = "dummy" ]; then
+              expected_content="This is a dummy response"
+            else
+              expected_content="Michael Jackson"
+            fi
+
             echo "Sending request to first instance..."
             response_1=$(curl -s http://localhost:8000/v1/chat/completions \
               -H "Content-Type: application/json" \
@@ -74,8 +81,8 @@ commands:
             kill $PID1 $PID2
 
             echo ""
-            if ! echo "$response_1" | grep -q "Michael Jackson" || ! echo "$response_2" | grep -q "Michael Jackson"; then
-              echo "Test failed: Response does not contain 'Michael Jackson'"
+            if ! echo "$response_1" | grep -q "$expected_content" || ! echo "$response_2" | grep -q "$expected_content"; then
+              echo "Test failed: Response does not contain '$expected_content'"
               echo "Response 1: $response_1"
               echo ""
               echo "Response 2: $response_2"
@@ -85,7 +92,7 @@ commands:
               cat output2.log
               exit 1
             else
-              echo "Test passed: Response from both nodes contains 'Michael Jackson'"
+              echo "Test passed: Response from both nodes contains '$expected_content'"
             fi
 
 jobs:
@@ -178,6 +185,28 @@ jobs:
           inference_engine: mlx
           model_id: llama-3.2-1b
 
+  chatgpt_api_integration_test_dummy:
+    macos:
+      xcode: "16.0.0"
+    resource_class: m2pro.large
+    steps:
+      - checkout
+      - run:
+          name: Set up Python
+          command: |
+            brew install python@3.12
+            python3.12 -m venv env
+            source env/bin/activate
+      - run:
+          name: Install dependencies
+          command: |
+            source env/bin/activate
+            pip install --upgrade pip
+            pip install .
+      - run_chatgpt_api_test:
+          inference_engine: dummy
+          model_id: dummy-model
+
   test_macos_m1:
     macos:
       xcode: "16.0.0"
@@ -215,5 +244,6 @@ workflows:
       - unit_test
       - discovery_integration_test
       - chatgpt_api_integration_test_mlx
+      - chatgpt_api_integration_test_dummy
       - test_macos_m1
       # - chatgpt_api_integration_test_tinygrad

+ 1 - 1
.gitignore

@@ -1,5 +1,5 @@
 __pycache__/
-.venv
+.venv*
 test_weights.npz
 .exo_used_ports
 .exo_node_id

+ 65 - 0
exo/inference/dummy_inference_engine.py

@@ -0,0 +1,65 @@
+from typing import Optional, Tuple, TYPE_CHECKING
+import numpy as np
+import asyncio
+import json
+from exo.inference.inference_engine import InferenceEngine
+from exo.inference.shard import Shard
+
+class DummyInferenceEngine(InferenceEngine):
+    def __init__(self, shard_downloader):
+        self.shard = None
+        self.shard_downloader = shard_downloader
+        self.vocab_size = 1000
+        self.eos_token_id = 0
+        self.latency_mean = 0.1
+        self.latency_stddev = 0.02
+
+    async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+        try:
+            await self.ensure_shard(shard)
+            
+            # Generate random tokens
+            output_length = np.random.randint(1, 10)
+            output = np.random.randint(1, self.vocab_size, size=(1, output_length))
+            
+            # Simulate latency
+            await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
+            
+            # Randomly decide if finished
+            is_finished = np.random.random() < 0.2
+            if is_finished:
+                output = np.array([[self.eos_token_id]])
+            
+            new_state = json.dumps({"dummy_state": "some_value"})
+            
+            return output, new_state, is_finished
+        except Exception as e:
+            print(f"Error in DummyInferenceEngine.infer_prompt: {str(e)}")
+            return np.array([[self.eos_token_id]]), json.dumps({"error": str(e)}), True
+
+    async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+        await self.ensure_shard(shard)
+        state = json.loads(inference_state or "{}")
+        start_pos = state.get("start_pos", 0)
+        
+        output_length = np.random.randint(1, 10)
+        output = np.random.randint(1, self.vocab_size, size=(1, output_length))
+        
+        await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
+        
+        is_finished = np.random.random() < 0.2
+        if is_finished:
+            output = np.array([[self.eos_token_id]])
+        
+        start_pos += input_data.shape[1] + output_length
+        new_state = json.dumps({"start_pos": start_pos})
+        
+        return output, new_state, is_finished
+
+    async def ensure_shard(self, shard: Shard):
+        if self.shard == shard:
+            return
+        # Simulate shard loading without making any API calls
+        await asyncio.sleep(0.1)  # Simulate a short delay
+        self.shard = shard
+        print(f"DummyInferenceEngine: Simulated loading of shard {shard.model_id}")

+ 8 - 3
exo/inference/inference_engine.py

@@ -1,5 +1,6 @@
 import numpy as np
 import os
+from exo.helpers import DEBUG  # Make sure to import DEBUG
 
 from typing import Tuple, Optional
 from abc import ABC, abstractmethod
@@ -8,7 +9,7 @@ from .shard import Shard
 
 class InferenceEngine(ABC):
   @abstractmethod
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     pass
 
   @abstractmethod
@@ -17,6 +18,8 @@ class InferenceEngine(ABC):
 
 
 def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
+  if DEBUG >= 2:
+    print(f"get_inference_engine called with: {inference_engine_name}")
   if inference_engine_name == "mlx":
     from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 
@@ -27,5 +30,7 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDow
     tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
 
     return TinygradDynamicShardInferenceEngine(shard_downloader)
-  else:
-    raise ValueError(f"Inference engine {inference_engine_name} not supported")
+  elif inference_engine_name == "dummy":
+    from exo.inference.dummy_inference_engine import DummyInferenceEngine
+    return DummyInferenceEngine(shard_downloader)
+  raise ValueError(f"Unsupported inference engine: {inference_engine_name}")

+ 56 - 0
exo/inference/test_dummy_inference_engine.py

@@ -0,0 +1,56 @@
+import pytest
+import json
+import numpy as np
+from exo.inference.dummy_inference_engine import DummyInferenceEngine
+from exo.inference.shard import Shard
+
+class MockShardDownloader:
+    async def ensure_shard(self, shard):
+        pass
+@pytest.mark.asyncio
+async def test_dummy_inference_specific():
+    engine = DummyInferenceEngine(MockShardDownloader())
+    test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
+    test_prompt = "This is a test prompt"
+    
+    result, state, is_finished = await engine.infer_prompt("test_request", test_shard, test_prompt)
+    
+    print(f"Inference result shape: {result.shape}")
+    print(f"Inference state: {state}")
+    print(f"Is finished: {is_finished}")
+    
+    assert result.shape[0] == 1, "Result should be a 2D array with first dimension 1"
+    assert isinstance(json.loads(state), dict), "State should be a valid JSON string"
+    assert isinstance(is_finished, bool), "is_finished should be a boolean"
+
+@pytest.mark.asyncio
+async def test_dummy_inference_engine():
+    # Initialize the DummyInferenceEngine
+    engine = DummyInferenceEngine(MockShardDownloader())
+    
+    # Create a test shard
+    shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
+    
+    # Test infer_prompt
+    output, state, is_finished = await engine.infer_prompt("test_id", shard, "Test prompt")
+    
+    assert isinstance(output, np.ndarray), "Output should be a numpy array"
+    assert output.ndim == 2, "Output should be 2-dimensional"
+    assert isinstance(state, str), "State should be a string"
+    assert isinstance(is_finished, bool), "is_finished should be a boolean"
+
+    # Test infer_tensor
+    input_tensor = np.array([[1, 2, 3]])
+    output, state, is_finished = await engine.infer_tensor("test_id", shard, input_tensor)
+    
+    assert isinstance(output, np.ndarray), "Output should be a numpy array"
+    assert output.ndim == 2, "Output should be 2-dimensional"
+    assert isinstance(state, str), "State should be a string"
+    assert isinstance(is_finished, bool), "is_finished should be a boolean"
+
+    print("All tests passed!")
+
+if __name__ == "__main__":
+    import asyncio
+    asyncio.run(test_dummy_inference_engine())
+    asyncio.run(test_dummy_inference_specific())

+ 6 - 1
exo/main.py

@@ -20,6 +20,7 @@ from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration.node import Node
 from exo.models import model_base_shards
@@ -44,13 +45,15 @@ parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of pee
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
-parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
+parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 args = parser.parse_args()
+print(f"Selected inference engine: {args.inference_engine}")
+
 
 print_yellow_exo()
 
@@ -60,6 +63,8 @@ print(f"Detected system: {system_info}")
 
 shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads)
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
+print(f"Inference engine name after selection: {inference_engine_name}")
+
 inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
 print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")