Browse Source

DummyInferenceEngine commit 1

rahat2134 8 months ago
parent
commit
7d6104750a

+ 1 - 1
.gitignore

@@ -1,5 +1,5 @@
 __pycache__/
 __pycache__/
-.venv
+.venv*
 test_weights.npz
 test_weights.npz
 .exo_used_ports
 .exo_used_ports
 .exo_node_id
 .exo_node_id

+ 68 - 0
exo/inference/dummy_inference_engine.py

@@ -0,0 +1,68 @@
+from typing import Optional, Tuple, TYPE_CHECKING
+import numpy as np
+import asyncio
+import json
+from typing import Optional, Tuple
+if TYPE_CHECKING:
+    from exo.inference.inference_engine import InferenceEngine
+from exo.inference.inference_engine import InferenceEngine
+from exo.inference.shard import Shard
+
+class DummyInferenceEngine(InferenceEngine):
+    def __init__(self, shard_downloader):
+        self.shard = None
+        self.shard_downloader = shard_downloader
+        self.vocab_size = 1000
+        self.eos_token_id = 0
+        self.latency_mean = 0.1
+        self.latency_stddev = 0.02
+
+    async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+        try:
+            await self.ensure_shard(shard)
+            
+            # Generate random tokens
+            output_length = np.random.randint(1, 10)
+            output = np.random.randint(1, self.vocab_size, size=(1, output_length))
+            
+            # Simulate latency
+            await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
+            
+            # Randomly decide if finished
+            is_finished = np.random.random() < 0.2
+            if is_finished:
+                output = np.array([[self.eos_token_id]])
+            
+            new_state = json.dumps({"dummy_state": "some_value"})
+            
+            return output, new_state, is_finished
+        except Exception as e:
+            print(f"Error in DummyInferenceEngine.infer_prompt: {str(e)}")
+            return np.array([[self.eos_token_id]]), json.dumps({"error": str(e)}), True
+
+    async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+        await self.ensure_shard(shard)
+        state = json.loads(inference_state or "{}")
+        start_pos = state.get("start_pos", 0)
+        
+        output_length = np.random.randint(1, 10)
+        output = np.random.randint(1, self.vocab_size, size=(1, output_length))
+        
+        await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
+        
+        is_finished = np.random.random() < 0.2
+        if is_finished:
+            output = np.array([[self.eos_token_id]])
+        
+        start_pos += input_data.shape[1] + output_length
+        new_state = json.dumps({"start_pos": start_pos})
+        
+        return output, new_state, is_finished
+
+    async def ensure_shard(self, shard: Shard):
+        if self.shard == shard:
+            return
+        # Simulate shard loading without making any API calls
+        await asyncio.sleep(0.1)  # Simulate a short delay
+        self.shard = shard
+        print(f"DummyInferenceEngine: Simulated loading of shard {shard.model_id}")

+ 6 - 2
exo/inference/inference_engine.py

@@ -8,7 +8,7 @@ from .shard import Shard
 
 
 class InferenceEngine(ABC):
 class InferenceEngine(ABC):
   @abstractmethod
   @abstractmethod
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     pass
     pass
 
 
   @abstractmethod
   @abstractmethod
@@ -17,6 +17,7 @@ class InferenceEngine(ABC):
 
 
 
 
 def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
 def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
+  print(f"get_inference_engine called with: {inference_engine_name}")  # Debug print
   if inference_engine_name == "mlx":
   if inference_engine_name == "mlx":
     from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
     from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 
 
@@ -27,5 +28,8 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDow
     tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
     tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
 
 
     return TinygradDynamicShardInferenceEngine(shard_downloader)
     return TinygradDynamicShardInferenceEngine(shard_downloader)
+  elif inference_engine_name == "dummy":
+    from exo.inference.dummy_inference_engine import DummyInferenceEngine
+    return DummyInferenceEngine(shard_downloader)
   else:
   else:
-    raise ValueError(f"Inference engine {inference_engine_name} not supported")
+      raise ValueError(f"Inference engine {inference_engine_name} not supported. Supported engines are 'mlx', 'tinygrad', and 'dummy'.")

+ 40 - 0
exo/inference/test_dummy_inference_engine.py

@@ -0,0 +1,40 @@
+import pytest
+import numpy as np
+from exo.inference.dummy_inference_engine import DummyInferenceEngine
+from exo.inference.shard import Shard
+
+@pytest.mark.asyncio
+async def test_dummy_inference_engine():
+    # Create a mock shard downloader
+    class MockShardDownloader:
+        async def ensure_shard(self, shard):
+            pass
+
+    # Initialize the DummyInferenceEngine
+    engine = DummyInferenceEngine(MockShardDownloader())
+    
+    # Create a test shard
+    shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
+    
+    # Test infer_prompt
+    output, state, is_finished = await engine.infer_prompt("test_id", shard, "Test prompt")
+    
+    assert isinstance(output, np.ndarray), "Output should be a numpy array"
+    assert output.ndim == 2, "Output should be 2-dimensional"
+    assert isinstance(state, str), "State should be a string"
+    assert isinstance(is_finished, bool), "is_finished should be a boolean"
+
+    # Test infer_tensor
+    input_tensor = np.array([[1, 2, 3]])
+    output, state, is_finished = await engine.infer_tensor("test_id", shard, input_tensor)
+    
+    assert isinstance(output, np.ndarray), "Output should be a numpy array"
+    assert output.ndim == 2, "Output should be 2-dimensional"
+    assert isinstance(state, str), "State should be a string"
+    assert isinstance(is_finished, bool), "is_finished should be a boolean"
+
+    print("All tests passed!")
+
+if __name__ == "__main__":
+    import asyncio
+    asyncio.run(test_dummy_inference_engine())

+ 25 - 1
exo/main.py

@@ -18,6 +18,7 @@ from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration.node import Node
 from exo.orchestration.node import Node
 from exo.models import model_base_shards
 from exo.models import model_base_shards
@@ -41,13 +42,15 @@ parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of pee
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
 parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
-parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
+parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 args = parser.parse_args()
 args = parser.parse_args()
+print(f"Selected inference engine: {args.inference_engine}")
+
 
 
 print_yellow_exo()
 print_yellow_exo()
 
 
@@ -56,6 +59,15 @@ print(f"Detected system: {system_info}")
 
 
 shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads)
 shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads)
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
+print(f"Inference engine name after selection: {inference_engine_name}")
+
+if inference_engine_name not in ["mlx", "tinygrad", "dummy"]:
+  print(f"Warning: Unknown inference engine '{inference_engine_name}'. Defaulting to 'tinygrad'.")
+  inference_engine_name = "tinygrad"
+else:
+  print(f"Using selected inference engine: {inference_engine_name}")
+
+print(f"About to call get_inference_engine with: {inference_engine_name}")
 inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
 inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
 print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
 print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
 
 
@@ -173,6 +185,16 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
     node.on_token.deregister(callback_id)
     node.on_token.deregister(callback_id)
 
 
 
 
+async def test_dummy_inference(inference_engine):
+    print("Testing DummyInferenceEngine...")
+    test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
+    test_prompt = "This is a test prompt"
+    result, state, is_finished = await inference_engine.infer_prompt("test_request", test_shard, test_prompt)
+    print(f"Inference result shape: {result.shape}")
+    print(f"Inference state: {state}")
+    print(f"Is finished: {is_finished}")
+
+
 async def main():
 async def main():
   loop = asyncio.get_running_loop()
   loop = asyncio.get_running_loop()
 
 
@@ -193,6 +215,8 @@ async def main():
     await run_model_cli(node, inference_engine, model_name, args.prompt)
     await run_model_cli(node, inference_engine, model_name, args.prompt)
   else:
   else:
     asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
     asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
+    if isinstance(node.inference_engine, DummyInferenceEngine):
+      await test_dummy_inference(node.inference_engine)
     await asyncio.Event().wait()
     await asyncio.Event().wait()