
Feedback 1: requested changes applied

Commit 0fdd5c7c2d by rahat2134 · 9 months ago

+ 0 - 3
exo/inference/dummy_inference_engine.py

@@ -2,9 +2,6 @@ from typing import Optional, Tuple, TYPE_CHECKING
 import numpy as np
 import asyncio
 import json
-from typing import Optional, Tuple
-if TYPE_CHECKING:
-    from exo.inference.inference_engine import InferenceEngine
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
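
Note: the removed guard was dead weight here: since the module imports InferenceEngine unconditionally on the very next line, the TYPE_CHECKING block (and the duplicate typing import) could never take effect. For contrast, a minimal sketch of where the guard genuinely pays off, using illustrative names that are not from exo:

# Sketch only: TYPE_CHECKING is useful when a runtime import would create a
# circular dependency; module and class names below are illustrative.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers (mypy, pyright) but never executed at runtime.
    from mypkg.engine import Engine

class Plugin:
    def attach(self, engine: 'Engine') -> None:
        # The quoted annotation defers resolution, so no runtime import is needed.
        self.engine = engine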
 

+ 4 - 3
exo/inference/inference_engine.py

@@ -1,5 +1,6 @@
 import numpy as np
 import os
+from exo.helpers import DEBUG  # verbosity level used to gate debug prints
 
 from typing import Tuple, Optional
 from abc import ABC, abstractmethod
@@ -17,7 +18,8 @@ class InferenceEngine(ABC):
 
 
 def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
-  print(f"get_inference_engine called with: {inference_engine_name}")  # Debug print
+  if DEBUG >= 2:
+    print(f"get_inference_engine called with: {inference_engine_name}")
   if inference_engine_name == "mlx":
     from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 
@@ -31,5 +33,4 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDow
   elif inference_engine_name == "dummy":
     from exo.inference.dummy_inference_engine import DummyInferenceEngine
     return DummyInferenceEngine(shard_downloader)
-  else:
-      raise ValueError(f"Inference engine {inference_engine_name} not supported. Supported engines are 'mlx', 'tinygrad', and 'dummy'.")
+  raise ValueError(f"Unsupported inference engine: {inference_engine_name}")
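
Note: the print is now gated behind exo's DEBUG convention instead of firing unconditionally. A minimal sketch of how such a helper typically works, assuming (check exo/helpers.py for the real definition) that DEBUG is an integer verbosity level read from the environment:

import os

# Assumption: DEBUG is an integer verbosity level taken from the environment.
DEBUG = int(os.getenv("DEBUG", default="0"))

# Call sites gate their output on a threshold, so e.g. DEBUG=2 enables:
if DEBUG >= 2:
    print("get_inference_engine called with: mlx")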

+ 22 - 6
exo/inference/test_dummy_inference_engine.py

@@ -1,15 +1,30 @@
 import pytest
+import json
 import numpy as np
 from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.shard import Shard
 
+class MockShardDownloader:
+    async def ensure_shard(self, shard):
+        pass
 @pytest.mark.asyncio
-async def test_dummy_inference_engine():
-    # Create a mock shard downloader
-    class MockShardDownloader:
-        async def ensure_shard(self, shard):
-            pass
+async def test_dummy_inference_specific():
+    engine = DummyInferenceEngine(MockShardDownloader())
+    test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
+    test_prompt = "This is a test prompt"
+    
+    result, state, is_finished = await engine.infer_prompt("test_request", test_shard, test_prompt)
+    
+    print(f"Inference result shape: {result.shape}")
+    print(f"Inference state: {state}")
+    print(f"Is finished: {is_finished}")
+    
+    assert result.shape[0] == 1, "Result's first dimension should be 1"
+    assert isinstance(json.loads(state), dict), "State should be a valid JSON string"
+    assert isinstance(is_finished, bool), "is_finished should be a boolean"
 
+@pytest.mark.asyncio
+async def test_dummy_inference_engine():
     # Initialize the DummyInferenceEngine
     engine = DummyInferenceEngine(MockShardDownloader())
     
@@ -37,4 +52,5 @@ async def test_dummy_inference_engine():
 
 if __name__ == "__main__":
     import asyncio
-    asyncio.run(test_dummy_inference_engine())
+    asyncio.run(test_dummy_inference_engine())
+    asyncio.run(test_dummy_inference_specific())
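
Note: MockShardDownloader implements the only method the dummy engine calls, ensure_shard, so hoisting it to module level lets both tests share it. If a later test needs to verify which shard the engine requested, a recording variant is a natural extension (hypothetical, not part of this commit):

class RecordingShardDownloader:
    """Hypothetical mock that records requested shards for assertions."""
    def __init__(self):
        self.requested = []

    async def ensure_shard(self, shard):
        self.requested.append(shard)

Under pytest the __main__ block is ignored; the @pytest.mark.asyncio marker has pytest-asyncio drive both coroutines.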

+ 0 - 19
exo/main.py

@@ -61,13 +61,6 @@ shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 print(f"Inference engine name after selection: {inference_engine_name}")
 
-if inference_engine_name not in ["mlx", "tinygrad", "dummy"]:
-  print(f"Warning: Unknown inference engine '{inference_engine_name}'. Defaulting to 'tinygrad'.")
-  inference_engine_name = "tinygrad"
-else:
-  print(f"Using selected inference engine: {inference_engine_name}")
-
-print(f"About to call get_inference_engine with: {inference_engine_name}")
 inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
 print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
 
@@ -185,16 +178,6 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
     node.on_token.deregister(callback_id)
 
 
-async def test_dummy_inference(inference_engine):
-    print("Testing DummyInferenceEngine...")
-    test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
-    test_prompt = "This is a test prompt"
-    result, state, is_finished = await inference_engine.infer_prompt("test_request", test_shard, test_prompt)
-    print(f"Inference result shape: {result.shape}")
-    print(f"Inference state: {state}")
-    print(f"Is finished: {is_finished}")
-
-
 async def main():
   loop = asyncio.get_running_loop()
 
@@ -215,8 +198,6 @@ async def main():
     await run_model_cli(node, inference_engine, model_name, args.prompt)
   else:
     asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
-    if isinstance(node.inference_engine, DummyInferenceEngine):
-      await test_dummy_inference(node.inference_engine)
     await asyncio.Event().wait()
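
Note: with the whitelist-and-fallback block gone, an unknown engine name no longer defaults silently to tinygrad; it fails fast inside get_inference_engine. Illustrative only (shard_downloader stands for whatever main.py constructed):

from exo.inference.inference_engine import get_inference_engine

try:
    engine = get_inference_engine("not-an-engine", shard_downloader)
except ValueError as err:
    print(err)  # Unsupported inference engine: not-an-engine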