瀏覽代碼

default to llama-3-8b and temperature=0 if not provided

Alex Cheema 9 月之前
父節點
當前提交
5de2ea51f5
共有 1 個文件被更改,包括 2 次插入和 2 次删除
  1. 2 2
      exo/api/chatgpt_api.py

+ 2 - 2
exo/api/chatgpt_api.py

@@ -70,7 +70,7 @@ class ChatGPTAPI:
     async def handle_post(self, request):
     async def handle_post(self, request):
         data = await request.json()
         data = await request.json()
         messages = [Message(**msg) for msg in data['messages']]
         messages = [Message(**msg) for msg in data['messages']]
-        chat_request = ChatCompletionRequest(data['model'], messages, data['temperature'])
+        chat_request = ChatCompletionRequest(data.get('model', 'llama-3-8b'), messages, data.get('temperature', 0.0))
         prompt = " ".join([msg.content for msg in chat_request.messages if msg.role == "user"])
         prompt = " ".join([msg.content for msg in chat_request.messages if msg.role == "user"])
         shard = shard_mappings.get(chat_request.model, {}).get(self.inference_engine_classname)
         shard = shard_mappings.get(chat_request.model, {}).get(self.inference_engine_classname)
         if not shard:
         if not shard:
@@ -137,7 +137,7 @@ class ChatGPTAPI:
         await runner.setup()
         await runner.setup()
         site = web.TCPSite(runner, host, port)
         site = web.TCPSite(runner, host, port)
         await site.start()
         await site.start()
-        if DEBUG >= 1: print(f"Starting ChatGPT API server at {host}:{port}")
+        if DEBUG >= 0: print(f"ChatGPT API server started at http://{host}:{port}")
 
 
 # Usage example
 # Usage example
 if __name__ == "__main__":
 if __name__ == "__main__":