@@ -166,7 +166,7 @@ class PromptSession:
    self.prompt = prompt

class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
    self.node = node
    self.inference_engine_classname = inference_engine_classname
    self.response_timeout = response_timeout
@@ -176,6 +176,7 @@ class ChatGPTAPI:
    self.prev_token_lens: Dict[str, int] = {}
    self.stream_tasks: Dict[str, asyncio.Task] = {}
    self.default_model = default_model or "llama-3.2-1b"
+   self.system_prompt = system_prompt

    cors = aiohttp_cors.setup(self.app)
    cors_options = aiohttp_cors.ResourceOptions(
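The two hunks above thread an optional `system_prompt` through `ChatGPTAPI.__init__`. A minimal sketch of how a caller might wire it up, assuming an argparse-style entry point (the `--system-prompt` flag name and the surrounding `node`/`inference_engine` objects are assumptions, not part of this diff):

```python
import argparse

# Hypothetical wiring; `node` and `inference_engine` come from the host
# application and are not defined in this diff.
parser = argparse.ArgumentParser()
parser.add_argument("--system-prompt", type=str, default=None,
                    help="system prompt injected into every chat request")
args = parser.parse_args()

api = ChatGPTAPI(
  node,
  inference_engine.__class__.__name__,
  default_model="llama-3.2-1b",
  system_prompt=args.system_prompt,  # None keeps the old behavior
)
```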
@@ -253,7 +254,7 @@ class ChatGPTAPI:
    )
    await response.prepare(request)

-   for model_name, pretty in pretty_name.items():
+   async def process_model(model_name, pretty):
      if model_name in model_cards:
        model_info = model_cards[model_name]

@@ -281,6 +282,12 @@ class ChatGPTAPI:

          await response.write(f"data: {json.dumps(model_data)}\n\n".encode())

+   # Process all models in parallel
+   await asyncio.gather(*[
+     process_model(model_name, pretty)
+     for model_name, pretty in pretty_name.items()
+   ])
+
    await response.write(b"data: [DONE]\n\n")
    return response

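Hunks 3 and 4 together replace the sequential per-model loop with an `async def process_model` plus `asyncio.gather`, so a slow status probe for one model no longer blocks the rest of the stream. A self-contained sketch of the pattern (names and the simulated work are illustrative, not taken from exo):

```python
import asyncio

async def process_model(model_name: str, pretty: str) -> None:
  # Stand-in for the real work: probe download status, then write one
  # SSE "data:" event for this model.
  await asyncio.sleep(0.1)
  print(f'data: {{"id": "{model_name}", "name": "{pretty}"}}')

async def main() -> None:
  pretty_name = {"llama-3.2-1b": "Llama 3.2 1B", "llama-3.2-3b": "Llama 3.2 3B"}
  # Concurrent: total latency is roughly the slowest model, not the sum.
  await asyncio.gather(*(process_model(m, p) for m, p in pretty_name.items()))
  print("data: [DONE]")

asyncio.run(main())
```

One consequence worth noting: events now arrive in completion order rather than dict order, so clients should key on the `id` field; the `[DONE]` sentinel is still written only after `gather` resolves, so the stream remains complete.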
@@ -293,7 +300,8 @@ class ChatGPTAPI:
    )

  async def handle_get_models(self, request):
-   return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
+   models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
+   return web.json_response({"object": "list", "data": models_list})

  async def handle_post_chat_token_encode(self, request):
    data = await request.json()
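The `handle_get_models` change wraps the bare array in the `{"object": "list", "data": [...]}` envelope that OpenAI-style clients expect. A quick client-side check, assuming the API is reachable on localhost (the port is an assumption; adjust to your deployment):

```python
import json
import urllib.request

# Port 52415 is an assumption; substitute whatever your node listens on.
with urllib.request.urlopen("http://localhost:52415/v1/models") as resp:
  body = json.load(resp)

assert body["object"] == "list"
for model in body["data"]:
  print(model["id"], "ready" if model.get("ready") else "not ready")
```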
@@ -345,6 +353,10 @@ class ChatGPTAPI:
    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
    if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

+   # Add system prompt if set
+   if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
+     chat_request.messages.insert(0, Message("system", self.system_prompt))
+
    prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
    request_id = str(uuid.uuid4())
    if self.on_chat_completion_request:
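The injection in hunk 6 is deliberately conservative: the configured prompt is prepended only when the request carries no system message of its own, so per-request prompts always take precedence. A self-contained illustration of the rule (this `Message` is a stand-in with the same `role`/`content` shape, not the exo class):

```python
from typing import Optional

class Message:
  def __init__(self, role: str, content: str):
    self.role = role
    self.content = content

def apply_system_prompt(messages: list, system_prompt: Optional[str]) -> None:
  # Mirrors the diff: prepend only if no system message is already present.
  if system_prompt and not any(m.role == "system" for m in messages):
    messages.insert(0, Message("system", system_prompt))

msgs = [Message("user", "Hello!")]
apply_system_prompt(msgs, "You are a helpful assistant.")
assert msgs[0].role == "system"          # injected

msgs = [Message("system", "Custom."), Message("user", "Hi")]
apply_system_prompt(msgs, "You are a helpful assistant.")
assert sum(m.role == "system" for m in msgs) == 1  # request's own prompt wins
```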
@@ -645,7 +657,7 @@ class ChatGPTAPI:
    if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
    shard = build_base_shard(model_name, self.inference_engine_classname)
    if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
-   asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
+   asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))

    return web.json_response({
      "status": "success",
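The last hunk routes the pre-download through `shard_downloader.ensure_shard(shard, inference_engine_classname)`, which takes the engine class name so it can resolve the matching repo (the same pairing seen in `get_repo(shard.model_id, self.inference_engine_classname)` above), rather than going through the engine's own `ensure_shard(shard)`. Since the task is fire-and-forget, one hedged refinement, not part of this diff, is to attach a done-callback so failures are logged instead of silently dropped:

```python
import asyncio
import traceback

def _log_failure(task: asyncio.Task) -> None:
  # Surfaces exceptions from the fire-and-forget download task (Python 3.10+
  # accepts a single exception argument to print_exception).
  if not task.cancelled() and task.exception() is not None:
    traceback.print_exception(task.exception())

# As in the handler above; `self` and `shard` come from the surrounding method.
task = asyncio.create_task(
  self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname)
)
task.add_done_callback(_log_failure)
```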