@@ -18,6 +18,7 @@ from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration.node import Node
 from exo.models import model_base_shards
@@ -41,13 +42,15 @@ parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of pee
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
-parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
+parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 args = parser.parse_args()
+print(f"Selected inference engine: {args.inference_engine}")
+
 
 print_yellow_exo()
 
@@ -56,6 +59,15 @@ print(f"Detected system: {system_info}")
 
 shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads)
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
+print(f"Inference engine name after selection: {inference_engine_name}")
+
+if inference_engine_name not in ["mlx", "tinygrad", "dummy"]:
+  print(f"Warning: Unknown inference engine '{inference_engine_name}'. Defaulting to 'tinygrad'.")
+  inference_engine_name = "tinygrad"
+else:
+  print(f"Using selected inference engine: {inference_engine_name}")
+
+print(f"About to call get_inference_engine with: {inference_engine_name}")
 inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
 print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
 
@@ -173,6 +185,16 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
     node.on_token.deregister(callback_id)
 
 
+async def test_dummy_inference(inference_engine):
+  print("Testing DummyInferenceEngine...")
+  test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
+  test_prompt = "This is a test prompt"
+  result, state, is_finished = await inference_engine.infer_prompt("test_request", test_shard, test_prompt)
+  print(f"Inference result shape: {result.shape}")
+  print(f"Inference state: {state}")
+  print(f"Is finished: {is_finished}")
+
+
 async def main():
   loop = asyncio.get_running_loop()
 
@@ -193,6 +215,8 @@ async def main():
     await run_model_cli(node, inference_engine, model_name, args.prompt)
   else:
     asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
+    if isinstance(node.inference_engine, DummyInferenceEngine):
+      await test_dummy_inference(node.inference_engine)
     await asyncio.Event().wait()
 
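
As a reading aid, the snippet below sketches the (result, inference_state, is_finished) contract that test_dummy_inference relies on. It is a hypothetical illustration only: the class name, output shape, and state value here are assumptions, and the actual implementation is the DummyInferenceEngine imported from exo.inference.dummy_inference_engine in the diff above.

import numpy as np

class DummyInferenceEngineSketch:
  # Hypothetical stand-in, not the exo implementation: returns a placeholder
  # array, an assumed-empty inference state, and an immediate finished flag,
  # matching the tuple unpacked in test_dummy_inference.
  async def infer_prompt(self, request_id: str, shard, prompt: str):
    output = np.random.rand(1, 1)  # placeholder array; result.shape is what the test prints
    inference_state = ""           # assumed opaque/empty state
    is_finished = True             # finish immediately so test callers return quickly
    return output, inference_state, is_finished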