|
@@ -61,13 +61,6 @@ shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_
|
|
|
inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
|
|
|
print(f"Inference engine name after selection: {inference_engine_name}")
|
|
|
|
|
|
-if inference_engine_name not in ["mlx", "tinygrad", "dummy"]:
|
|
|
- print(f"Warning: Unknown inference engine '{inference_engine_name}'. Defaulting to 'tinygrad'.")
|
|
|
- inference_engine_name = "tinygrad"
|
|
|
-else:
|
|
|
- print(f"Using selected inference engine: {inference_engine_name}")
|
|
|
-
|
|
|
-print(f"About to call get_inference_engine with: {inference_engine_name}")
|
|
|
inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
|
|
|
print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
|
|
|
|
|
@@ -185,16 +178,6 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
|
|
|
node.on_token.deregister(callback_id)
|
|
|
|
|
|
|
|
|
-async def test_dummy_inference(inference_engine):
|
|
|
- print("Testing DummyInferenceEngine...")
|
|
|
- test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
|
|
|
- test_prompt = "This is a test prompt"
|
|
|
- result, state, is_finished = await inference_engine.infer_prompt("test_request", test_shard, test_prompt)
|
|
|
- print(f"Inference result shape: {result.shape}")
|
|
|
- print(f"Inference state: {state}")
|
|
|
- print(f"Is finished: {is_finished}")
|
|
|
-
|
|
|
-
|
|
|
async def main():
|
|
|
loop = asyncio.get_running_loop()
|
|
|
|
|
@@ -215,8 +198,6 @@ async def main():
|
|
|
await run_model_cli(node, inference_engine, model_name, args.prompt)
|
|
|
else:
|
|
|
asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
|
|
|
- if isinstance(node.inference_engine, DummyInferenceEngine):
|
|
|
- await test_dummy_inference(node.inference_engine)
|
|
|
await asyncio.Event().wait()
|
|
|
|
|
|
|