Timothy Jaeryang Baek 2 days ago
parent
commit
1f123eb100
2 changed files with 46 additions and 33 deletions
  1. 31 24
      backend/open_webui/retrieval/utils.py
  2. 15 9
      backend/open_webui/utils/middleware.py

+ 31 - 24
backend/open_webui/retrieval/utils.py

@@ -460,20 +460,19 @@ def get_sources_from_files(
     )
     )
 
 
     extracted_collections = []
     extracted_collections = []
-    relevant_contexts = []
+    query_results = []
 
 
     for file in files:
     for file in files:
-
-        context = None
+        query_result = None
         if file.get("docs"):
         if file.get("docs"):
             # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
             # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
-            context = {
+            query_result = {
                 "documents": [[doc.get("content") for doc in file.get("docs")]],
                 "documents": [[doc.get("content") for doc in file.get("docs")]],
                 "metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
                 "metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
             }
             }
         elif file.get("context") == "full":
         elif file.get("context") == "full":
             # Manual Full Mode Toggle
             # Manual Full Mode Toggle
-            context = {
+            query_result = {
                 "documents": [[file.get("file").get("data", {}).get("content")]],
                 "documents": [[file.get("file").get("data", {}).get("content")]],
                 "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
                 "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
             }
             }
@@ -500,7 +499,7 @@ def get_sources_from_files(
                             }
                             }
                         )
                         )
 
 
-                context = {
+                query_result = {
                     "documents": [documents],
                     "documents": [documents],
                     "metadatas": [metadatas],
                     "metadatas": [metadatas],
                 }
                 }
@@ -508,7 +507,7 @@ def get_sources_from_files(
             elif file.get("id"):
             elif file.get("id"):
                 file_object = Files.get_file_by_id(file.get("id"))
                 file_object = Files.get_file_by_id(file.get("id"))
                 if file_object:
                 if file_object:
-                    context = {
+                    query_result = {
                         "documents": [[file_object.data.get("content", "")]],
                         "documents": [[file_object.data.get("content", "")]],
                         "metadatas": [
                         "metadatas": [
                             [
                             [
@@ -521,7 +520,7 @@ def get_sources_from_files(
                         ],
                         ],
                     }
                     }
             elif file.get("file").get("data"):
             elif file.get("file").get("data"):
-                context = {
+                query_result = {
                     "documents": [[file.get("file").get("data", {}).get("content")]],
                     "documents": [[file.get("file").get("data", {}).get("content")]],
                     "metadatas": [
                     "metadatas": [
                         [file.get("file").get("data", {}).get("metadata", {})]
                         [file.get("file").get("data", {}).get("metadata", {})]
@@ -549,19 +548,27 @@ def get_sources_from_files(
 
 
             if full_context:
             if full_context:
                 try:
                 try:
-                    context = get_all_items_from_collections(collection_names)
+                    query_result = get_all_items_from_collections(collection_names)
                 except Exception as e:
                 except Exception as e:
                     log.exception(e)
                     log.exception(e)
 
 
             else:
             else:
                 try:
                 try:
-                    context = None
+                    query_result = None
                     if file.get("type") == "text":
                     if file.get("type") == "text":
-                        context = file["content"]
+                        # Unclear when this branch is reached; it appears to be a fallback for text-type items
+                        query_result = {
+                            "documents": [
+                                [file.get("file").get("data", {}).get("content")]
+                            ],
+                            "metadatas": [
+                                [file.get("file").get("data", {}).get("meta", {})]
+                            ],
+                        }
                     else:
                     else:
                         if hybrid_search:
                         if hybrid_search:
                             try:
                             try:
-                                context = query_collection_with_hybrid_search(
+                                query_result = query_collection_with_hybrid_search(
                                     collection_names=collection_names,
                                     collection_names=collection_names,
                                     queries=queries,
                                     queries=queries,
                                     embedding_function=embedding_function,
                                     embedding_function=embedding_function,
@@ -577,8 +584,8 @@ def get_sources_from_files(
                                     " non hybrid search as fallback."
                                     " non hybrid search as fallback."
                                 )
                                 )
 
 
-                        if (not hybrid_search) or (context is None):
-                            context = query_collection(
+                        if (not hybrid_search) or (query_result is None):
+                            query_result = query_collection(
                                 collection_names=collection_names,
                                 collection_names=collection_names,
                                 queries=queries,
                                 queries=queries,
                                 embedding_function=embedding_function,
                                 embedding_function=embedding_function,
@@ -589,24 +596,24 @@ def get_sources_from_files(
 
 
             extracted_collections.extend(collection_names)
             extracted_collections.extend(collection_names)
 
 
-        if context:
+        if query_result:
             if "data" in file:
             if "data" in file:
                 del file["data"]
                 del file["data"]
 
 
-            relevant_contexts.append({**context, "file": file})
+            query_results.append({**query_result, "file": file})
 
 
     sources = []
     sources = []
-    for context in relevant_contexts:
+    for query_result in query_results:
         try:
         try:
-            if "documents" in context:
-                if "metadatas" in context:
+            if "documents" in query_result:
+                if "metadatas" in query_result:
                     source = {
                     source = {
-                        "source": context["file"],
-                        "document": context["documents"][0],
-                        "metadata": context["metadatas"][0],
+                        "source": query_result["file"],
+                        "document": query_result["documents"][0],
+                        "metadata": query_result["metadatas"][0],
                     }
                     }
-                    if "distances" in context and context["distances"]:
-                        source["distances"] = context["distances"][0]
+                    if "distances" in query_result and query_result["distances"]:
+                        source["distances"] = query_result["distances"][0]
 
 
                     sources.append(source)
                     sources.append(source)
         except Exception as e:
         except Exception as e:

+ 15 - 9
backend/open_webui/utils/middleware.py

@@ -718,6 +718,10 @@ def apply_params_to_form_data(form_data, model):
 
 
 
 
 async def process_chat_payload(request, form_data, user, metadata, model):
 async def process_chat_payload(request, form_data, user, metadata, model):
+    # Pipeline Inlet -> Filter Inlet -> Chat Memory -> Chat Web Search -> Chat Image Generation
+    # -> Chat Code Interpreter (Form Data Update) -> (Default) Chat Tools Function Calling
+    # -> Chat Files
+
     form_data = apply_params_to_form_data(form_data, model)
     form_data = apply_params_to_form_data(form_data, model)
     log.debug(f"form_data: {form_data}")
     log.debug(f"form_data: {form_data}")
 
 
@@ -911,7 +915,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                     request, form_data, extra_params, user, models, tools_dict
                     request, form_data, extra_params, user, models, tools_dict
                 )
                 )
                 sources.extend(flags.get("sources", []))
                 sources.extend(flags.get("sources", []))
-
             except Exception as e:
             except Exception as e:
                 log.exception(e)
                 log.exception(e)
 
 
@@ -924,24 +927,27 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     # If context is not empty, insert it into the messages
     # If context is not empty, insert it into the messages
     if len(sources) > 0:
     if len(sources) > 0:
         context_string = ""
         context_string = ""
-        citation_idx = {}
+        citation_idx_map = {}
+
         for source in sources:
         for source in sources:
             if "document" in source:
             if "document" in source:
-                for doc_context, doc_meta in zip(
+                for document_text, document_metadata in zip(
                     source["document"], source["metadata"]
                     source["document"], source["metadata"]
                 ):
                 ):
                     source_name = source.get("source", {}).get("name", None)
                     source_name = source.get("source", {}).get("name", None)
-                    citation_id = (
-                        doc_meta.get("source", None)
+                    source_id = (
+                        document_metadata.get("source", None)
                         or source.get("source", {}).get("id", None)
                         or source.get("source", {}).get("id", None)
                         or "N/A"
                         or "N/A"
                     )
                     )
-                    if citation_id not in citation_idx:
-                        citation_idx[citation_id] = len(citation_idx) + 1
+
+                    if source_id not in citation_idx_map:
+                        citation_idx_map[source_id] = len(citation_idx_map) + 1
+
                     context_string += (
                     context_string += (
-                        f'<source id="{citation_idx[citation_id]}"'
+                        f'<source id="{citation_idx_map[source_id]}"'
                         + (f' name="{source_name}"' if source_name else "")
                         + (f' name="{source_name}"' if source_name else "")
-                        + f">{doc_context}</source>\n"
+                        + f">{document_text}</source>\n"
                     )
                     )
 
 
         context_string = context_string.strip()
         context_string = context_string.strip()