
Add --system-prompt to exo cli

pepebruari, 4 months ago
Parent
Commit fe50d4d34d
2 changed files with 9 additions and 2 deletions
  1. +6 -1  exo/api/chatgpt_api.py
  2. +3 -1  exo/main.py

+6 -1
exo/api/chatgpt_api.py

@@ -160,7 +160,7 @@ class PromptSession:
     self.prompt = prompt
 
 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout = response_timeout
@@ -170,6 +170,7 @@ class ChatGPTAPI:
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
     self.default_model = default_model or "llama-3.2-1b"
+    self.system_prompt = system_prompt
 
     cors = aiohttp_cors.setup(self.app)
     cors_options = aiohttp_cors.ResourceOptions(
@@ -336,6 +337,10 @@ class ChatGPTAPI:
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
+    # Add system prompt if set
+    if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
+      chat_request.messages.insert(0, Message("system", self.system_prompt))
+
     prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
     request_id = str(uuid.uuid4())
     if self.on_chat_completion_request:
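
The guard on the injected message makes the CLI-level prompt a default, not an override: a request that already carries its own system message is left untouched. Below is a minimal runnable sketch of that behavior, using a simplified stand-in for the API's Message type (the real class lives in exo/api/chatgpt_api.py):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Message:
  role: str
  content: str

def apply_system_prompt(messages: List[Message], system_prompt: Optional[str]) -> List[Message]:
  # Mirrors the inserted logic: prepend the configured prompt only when
  # the request has no system message of its own.
  if system_prompt and not any(msg.role == "system" for msg in messages):
    messages.insert(0, Message("system", system_prompt))
  return messages

# No system message in the request: the CLI default is prepended.
assert apply_system_prompt([Message("user", "Hi")], "Be terse.")[0].role == "system"
# An explicit system message wins; the CLI default is not injected.
assert apply_system_prompt([Message("system", "Custom"), Message("user", "Hi")], "Be terse.")[0].content == "Custom"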

+3 -1
exo/main.py

@@ -69,6 +69,7 @@ parser.add_argument("--default-temp", type=float, help="Default token sampling t
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
+parser.add_argument("--system-prompt", type=str, default=None, help="System prompt for the ChatGPT API")
 args = parser.parse_args()
 print(f"Selected inference engine: {args.inference_engine}")
 
@@ -146,7 +147,8 @@ api = ChatGPTAPI(
   inference_engine.__class__.__name__,
   response_timeout=args.chatgpt_api_response_timeout,
   on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
-  default_model=args.default_model
+  default_model=args.default_model,
+  system_prompt=args.system_prompt
 )
 node.on_token.register("update_topology_viz").on_next(
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
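
With both changes in place, starting the node as `exo --system-prompt "Answer in one sentence."` should prepend that message to any chat request that does not supply its own system role. A hypothetical smoke test against the ChatGPT-compatible endpoint, assuming a node running locally (the port below is an assumption; use whatever address your node prints at startup):

import json
from urllib import request

payload = {
  "model": "llama-3.2-1b",  # the default_model fallback from chatgpt_api.py
  "messages": [{"role": "user", "content": "What is exo?"}],  # no system role, so the CLI prompt is injected
}
req = request.Request(
  "http://localhost:52415/v1/chat/completions",
  data=json.dumps(payload).encode(),
  headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as resp:
  print(json.load(resp)["choices"][0]["message"]["content"])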