@@ -1,10 +1,13 @@
 import argparse
 import asyncio
+import aiofiles
 import signal
 import json
 import time
 import traceback
 import uuid
+from typing import Optional
+from pathlib import Path
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
@@ -39,6 +42,7 @@ parser.add_argument("--inference-engine", type=str, default=None, help="Inferenc
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
+parser.add_argument("--file", type=str, help="File to use for the model when using --run-model", default=None)
 args = parser.parse_args()
 
 print_yellow_exo()
@@ -131,7 +135,14 @@ async def shutdown(signal, loop):
   loop.stop()
 
 
-async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
+async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str, file_path: Optional[str] = None):
+  if file_path:
+    try:
+      import textract
+      prompt = "Input file: " + textract.process(file_path).decode('utf-8') + "\n\n---\n\n" + prompt
+    except Exception as e:
+      print(f"Error reading file {file_path}: {str(e)}")
+      return
   shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
   if not shard:
     print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
@@ -145,7 +156,7 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
   prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
 
   try:
-    print(f"Processing prompt: {prompt}")
+    print(f"Processing prompt (len={len(prompt)}): {prompt}")
     await node.process_prompt(shard, prompt, None, request_id=request_id)
 
     _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
@@ -172,7 +183,7 @@ async def main():
   await node.start(wait_for_peers=args.wait_for_peers)
 
   if args.run_model:
-    await run_model_cli(node, inference_engine, args.run_model, args.prompt)
+    await run_model_cli(node, inference_engine, args.run_model, args.prompt, args.file)
   else:
     asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
     await asyncio.Event().wait()
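
A minimal sketch of the prompt composition the new --file flag performs, assuming textract is installed and importable; build_prompt is an illustrative helper name, not part of the diff:

from typing import Optional
import textract  # assumed available; the diff imports it lazily inside run_model_cli

def build_prompt(prompt: str, file_path: Optional[str] = None) -> str:
  # Mirrors the added logic: extract the file's text with textract and
  # prepend it to the user prompt, separated by a "---" divider.
  if not file_path:
    return prompt
  file_text = textract.process(file_path).decode("utf-8")
  return "Input file: " + file_text + "\n\n---\n\n" + prompt

For example, build_prompt("Summarize this document", "notes.pdf") returns the extracted file text, the divider, and then the question, which is what run_model_cli now feeds to the tokenizer's chat template.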