|
@@ -70,7 +70,7 @@ class ChatGPTAPI:
|
|
|
async def handle_post(self, request):
|
|
|
data = await request.json()
|
|
|
messages = [Message(**msg) for msg in data['messages']]
|
|
|
- chat_request = ChatCompletionRequest(data['model'], messages, data['temperature'])
|
|
|
+ chat_request = ChatCompletionRequest(data.get('model', 'llama-3-8b'), messages, data.get('temperature', 0.0))
|
|
|
prompt = " ".join([msg.content for msg in chat_request.messages if msg.role == "user"])
|
|
|
shard = shard_mappings.get(chat_request.model, {}).get(self.inference_engine_classname)
|
|
|
if not shard:
|
|
@@ -137,7 +137,7 @@ class ChatGPTAPI:
|
|
|
await runner.setup()
|
|
|
site = web.TCPSite(runner, host, port)
|
|
|
await site.start()
|
|
|
- if DEBUG >= 1: print(f"Starting ChatGPT API server at {host}:{port}")
|
|
|
+ if DEBUG >= 0: print(f"ChatGPT API server started at http://{host}:{port}")
|
|
|
|
|
|
# Usage example
|
|
|
if __name__ == "__main__":
|