Ver código fonte

fix: pipelines

Timothy J. Baek 11 meses atrás
pai
commit
cb8c45d864

+ 0 - 5
backend/apps/openai/main.py

@@ -400,11 +400,6 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
 
             if "pipeline" in model and model.get("pipeline"):
                 payload["user"] = {"name": user.name, "id": user.id}
-                payload["title"] = (
-                    True
-                    if payload["stream"] == False and payload["max_tokens"] == 50
-                    else False
-                )
 
             # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
             # This is a workaround until OpenAI fixes the issue with this model

+ 6 - 2
backend/main.py

@@ -315,8 +315,12 @@ class PipelineMiddleware(BaseHTTPMiddleware):
                     else:
                         pass
 
-            if "chat_id" in data:
-                del data["chat_id"]
+            if "pipeline" not in app.state.MODELS[model_id]:
+                if "chat_id" in data:
+                    del data["chat_id"]
+
+                if "title" in data:
+                    del data["title"]
 
             modified_body_bytes = json.dumps(data).encode("utf-8")
             # Replace the request body with the modified one

+ 4 - 1
src/lib/apis/openai/index.ts

@@ -336,6 +336,7 @@ export const generateTitle = async (
 	template: string,
 	model: string,
 	prompt: string,
+	chat_id?: string,
 	url: string = OPENAI_API_BASE_URL
 ) => {
 	let error = null;
@@ -361,7 +362,9 @@ export const generateTitle = async (
 			],
 			stream: false,
 			// Restricting the max tokens to 50 to avoid long titles
-			max_tokens: 50
+			max_tokens: 50,
+			...(chat_id && { chat_id: chat_id }),
+			title: true
 		})
 	})
 		.then(async (res) => {

+ 1 - 0
src/lib/components/chat/Chat.svelte

@@ -1118,6 +1118,7 @@
 					) + ' {{prompt}}',
 				titleModelId,
 				userPrompt,
+				$chatId,
 				titleModel?.owned_by === 'openai' ?? false
 					? `${OPENAI_API_BASE_URL}`
 					: `${OLLAMA_API_BASE_URL}/v1`