
add log_request middleware if DEBUG>=2 to chatgpt api to debug api issues, default always to llama-3.1-8b

Alex Cheema committed 9 months ago
parent commit 5a23376059
1 changed file with 17 additions and 2 deletions:
  1. exo/api/chatgpt_api.py  +17 -2

+17 -2
exo/api/chatgpt_api.py

@@ -169,7 +169,17 @@ class ChatGPTAPI:
         self.app.router.add_get('/', self.handle_root)
         self.app.router.add_static('/', self.static_dir, name='static')
 
+        # Add middleware to log every request
+        self.app.middlewares.append(self.log_request)
+
+    async def log_request(self, app, handler):
+        async def middleware(request):
+            if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
+            return await handler(request)
+        return middleware
+
     async def handle_root(self, request):
+        print(f"Handling root request from {request.remote}")
         return web.FileResponse(self.static_dir / 'index.html')
 
     async def handle_post_chat_token_encode(self, request):
@@ -181,13 +191,18 @@ class ChatGPTAPI:
 
     async def handle_post_chat_completions(self, request):
         data = await request.json()
+        if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
         stream = data.get('stream', False)
         chat_request = parse_chat_request(data)
         if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
             chat_request.model = "llama-3.1-8b"
-        shard = shard_mappings.get(chat_request.model, {}).get(self.inference_engine_classname)
+        if not chat_request.model or chat_request.model not in shard_mappings:
+            if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}. Defaulting to llama-3.1-8b")
+            chat_request.model = "llama-3.1-8b"
+        shard = shard_mappings[chat_request.model].get(self.inference_engine_classname, None)
         if not shard:
-            return web.json_response({'detail': f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}"}, status=400)
+            supported_models = [model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines]
+            return web.json_response({'detail': f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"}, status=400)
         request_id = str(uuid.uuid4())
 
         tokenizer = await resolve_tokenizer(shard.model_id)
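
For context on the first hunk: the commit registers an old-style aiohttp middleware factory (it takes app and handler and returns the actual middleware coroutine). Below is a minimal, self-contained sketch of the same request-logging idea written against aiohttp's newer @web.middleware style, whose signature is (request, handler). The DEBUG constant, the root handler body, and the port are assumptions for the sketch, not values taken from exo.

from aiohttp import web

DEBUG = 2  # assumption for this sketch; exo reads DEBUG from its own config/env

@web.middleware
async def log_request(request, handler):
    # Log every incoming request when verbose debugging is enabled,
    # then hand the request off to the next handler in the chain.
    if DEBUG >= 2:
        print(f"Received request: {request.method} {request.path}")
    return await handler(request)

async def handle_root(request):
    return web.Response(text="ok")

app = web.Application(middlewares=[log_request])
app.router.add_get("/", handle_root)

if __name__ == "__main__":
    web.run_app(app, port=8000)  # port is an assumption for the sketch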
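
And a hedged usage sketch for the model-fallback logic in the second hunk: any model name starting with "gpt-", or any name not present in shard_mappings, is served as llama-3.1-8b. The endpoint path /v1/chat/completions and the localhost:8000 address are assumptions here, since the route registration is not shown in this diff.

import json
import urllib.request

payload = {
    "model": "gpt-4",  # remapped to llama-3.1-8b by the handler above
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": False,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/chat/completions",  # assumed path and port
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(resp.status, json.loads(resp.read().decode("utf-8")))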