Ver código fonte

Merge pull request #17744 from Classic298/fix-rag-full-context

Fix: Prevent RAG queries when all files are in full context
Tim Jaeryang Baek 1 semana atrás
pai
commit
cd417ca0ba
1 arquivos alterados com 40 adições e 34 exclusões
  1. 40 34
      backend/open_webui/utils/middleware.py

+ 40 - 34
backend/open_webui/utils/middleware.py

@@ -641,48 +641,53 @@ async def chat_completion_files_handler(
     sources = []
 
     if files := body.get("metadata", {}).get("files", None):
-        queries = []
-        try:
-            queries_response = await generate_queries(
-                request,
-                {
-                    "model": body["model"],
-                    "messages": body["messages"],
-                    "type": "retrieval",
-                },
-                user,
-            )
-            queries_response = queries_response["choices"][0]["message"]["content"]
+        # Check if all files are in full context mode
+        all_full_context = all(item.get("context") == "full" for item in files)
 
+        queries = []
+        if not all_full_context:
             try:
-                bracket_start = queries_response.find("{")
-                bracket_end = queries_response.rfind("}") + 1
+                queries_response = await generate_queries(
+                    request,
+                    {
+                        "model": body["model"],
+                        "messages": body["messages"],
+                        "type": "retrieval",
+                    },
+                    user,
+                )
+                queries_response = queries_response["choices"][0]["message"]["content"]
 
-                if bracket_start == -1 or bracket_end == -1:
-                    raise Exception("No JSON object found in the response")
+                try:
+                    bracket_start = queries_response.find("{")
+                    bracket_end = queries_response.rfind("}") + 1
 
-                queries_response = queries_response[bracket_start:bracket_end]
-                queries_response = json.loads(queries_response)
-            except Exception as e:
-                queries_response = {"queries": [queries_response]}
+                    if bracket_start == -1 or bracket_end == -1:
+                        raise Exception("No JSON object found in the response")
 
-            queries = queries_response.get("queries", [])
-        except:
-            pass
+                    queries_response = queries_response[bracket_start:bracket_end]
+                    queries_response = json.loads(queries_response)
+                except Exception as e:
+                    queries_response = {"queries": [queries_response]}
+
+                queries = queries_response.get("queries", [])
+            except:
+                pass
 
         if len(queries) == 0:
             queries = [get_last_user_message(body["messages"])]
 
-        await __event_emitter__(
-            {
-                "type": "status",
-                "data": {
-                    "action": "queries_generated",
-                    "queries": queries,
-                    "done": False,
-                },
-            }
-        )
+        if not all_full_context:
+            await __event_emitter__(
+                {
+                    "type": "status",
+                    "data": {
+                        "action": "queries_generated",
+                        "queries": queries,
+                        "done": False,
+                    },
+                }
+            )
 
         try:
             # Offload get_sources_from_items to a separate thread
@@ -711,7 +716,8 @@ async def chat_completion_files_handler(
                         r=request.app.state.config.RELEVANCE_THRESHOLD,
                         hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
                         hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
-                        full_context=request.app.state.config.RAG_FULL_CONTEXT,
+                        full_context=all_full_context
+                        or request.app.state.config.RAG_FULL_CONTEXT,
                         user=user,
                     ),
                 )