@@ -160,7 +160,7 @@ class PromptSession:
     self.prompt = prompt
 
 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout = response_timeout
@@ -170,6 +170,7 @@ class ChatGPTAPI:
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
     self.default_model = default_model or "llama-3.2-1b"
+    self.system_prompt = system_prompt
 
     cors = aiohttp_cors.setup(self.app)
     cors_options = aiohttp_cors.ResourceOptions(
@@ -336,6 +337,10 @@ class ChatGPTAPI:
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
+    # Add the configured system prompt, unless the request already includes a system message
+    if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
+      chat_request.messages.insert(0, Message("system", self.system_prompt))
+
     prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
     request_id = str(uuid.uuid4())
     if self.on_chat_completion_request:
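
For reference, the guard added in the last hunk prepends the configured system prompt only when the incoming request carries no system message of its own, so a client-supplied system message always takes precedence. Below is a minimal standalone sketch of that behavior; the Message dataclass and inject_system_prompt helper here are illustrative stand-ins, not exo's own definitions.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Message:
  # Stand-in for exo's Message type; only the fields the guard inspects.
  role: str
  content: str

def inject_system_prompt(messages: List[Message], system_prompt: Optional[str]) -> List[Message]:
  # Same logic as the hunk above: prepend the server-wide system prompt
  # only when the client did not supply its own system message.
  if system_prompt and not any(msg.role == "system" for msg in messages):
    messages.insert(0, Message("system", system_prompt))
  return messages

# Without a client system message, the default is prepended:
msgs = inject_system_prompt([Message("user", "hello")], "You are a pirate.")
assert msgs[0].role == "system" and msgs[0].content == "You are a pirate."

# A client-supplied system message is left untouched:
msgs = inject_system_prompt([Message("system", "custom"), Message("user", "hello")], "You are a pirate.")
assert msgs[0].content == "custom" and len(msgs) == 2

At construction time, the new parameter would presumably be threaded through from a CLI flag or config value, e.g. ChatGPTAPI(node, engine_name, system_prompt=args.system_prompt); that wiring sits outside the hunks shown here.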