Procházet zdrojové kódy

refac: web search performance

Co-Authored-By: Mabeck <64421281+mmabeck@users.noreply.github.com>
Timothy Jaeryang Baek před 5 měsíci
rodič
revize
34ec10a78c

+ 47 - 36
backend/open_webui/routers/retrieval.py

@@ -3,6 +3,8 @@ import logging
 import mimetypes
 import os
 import shutil
+import asyncio
+
 
 import uuid
 from datetime import datetime
@@ -188,7 +190,7 @@ class ProcessUrlForm(CollectionNameForm):
 
 
 class SearchForm(BaseModel):
-    query: str
+    queries: List[str]
 
 
 @router.get("/")
@@ -1568,16 +1570,34 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
 async def process_web_search(
     request: Request, form_data: SearchForm, user=Depends(get_verified_user)
 ):
+
+    urls = []
     try:
         logging.info(
             f"trying to web search with {request.app.state.config.WEB_SEARCH_ENGINE, form_data.query}"
         )
-        web_results = await run_in_threadpool(
-            search_web,
-            request,
-            request.app.state.config.WEB_SEARCH_ENGINE,
-            form_data.query,
-        )
+
+        search_tasks = [
+            run_in_threadpool(
+                search_web,
+                request,
+                request.app.state.config.WEB_SEARCH_ENGINE,
+                query,
+            )
+            for query in form_data.queries
+        ]
+
+        search_results = await asyncio.gather(*search_tasks)
+
+        for result in search_results:
+            if result:
+                for item in result:
+                    if item and item.link:
+                        urls.append(item.link)
+
+        urls = list(dict.fromkeys(urls))
+        log.debug(f"urls: {urls}")
+
     except Exception as e:
         log.exception(e)
 
@@ -1586,15 +1606,7 @@ async def process_web_search(
             detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
         )
 
-    log.debug(f"web_results: {web_results}")
-
     try:
-        urls = [result.link for result in web_results]
-
-        # Remove duplicates
-        urls = list(dict.fromkeys(urls))
-        log.debug(f"urls: {urls}")
-
         loader = get_web_loader(
             urls,
             verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
@@ -1604,7 +1616,7 @@ async def process_web_search(
         docs = await loader.aload()
         urls = [
             doc.metadata.get("source") for doc in docs if doc.metadata.get("source")
-        ]  # only keep URLs
+        ]  # only keep the urls returned by the loader
 
         if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
             return {
@@ -1621,29 +1633,28 @@ async def process_web_search(
                 "loaded_count": len(docs),
             }
         else:
-            collection_names = []
-            for doc_idx, doc in enumerate(docs):
-                if doc and doc.page_content:
-                    try:
-                        collection_name = f"web-search-{calculate_sha256_string(form_data.query + '-' + urls[doc_idx])}"[
-                            :63
-                        ]
-
-                        collection_names.append(collection_name)
-                        await run_in_threadpool(
-                            save_docs_to_vector_db,
-                            request,
-                            [doc],
-                            collection_name,
-                            overwrite=True,
-                            user=user,
-                        )
-                    except Exception as e:
-                        log.debug(f"error saving doc {doc_idx}: {e}")
+            # Create a single collection for all documents
+            collection_name = (
+                f"web-search-{calculate_sha256_string('-'.join(form_data.queries))}"[
+                    :63
+                ]
+            )
+
+            try:
+                await run_in_threadpool(
+                    save_docs_to_vector_db,
+                    request,
+                    docs,
+                    collection_name,
+                    overwrite=True,
+                    user=user,
+                )
+            except Exception as e:
+                log.debug(f"error saving docs: {e}")
 
             return {
                 "status": True,
-                "collection_names": collection_names,
+                "collection_names": [collection_name],
                 "filenames": urls,
                 "loaded_count": len(docs),
             }

+ 49 - 82
backend/open_webui/utils/middleware.py

@@ -353,8 +353,6 @@ async def chat_web_search_handler(
         )
         return form_data
 
-    all_results = []
-
     await event_emitter(
         {
             "type": "status",
@@ -366,106 +364,75 @@ async def chat_web_search_handler(
         }
     )
 
-    gathered_results = await asyncio.gather(
-        *(
-            process_web_search(
-                request,
-                SearchForm(**{"query": searchQuery}),
-                user=user,
-            )
-            for searchQuery in queries
-        ),
-        return_exceptions=True,
-    )
+    try:
+        results = await process_web_search(
+            request,
+            SearchForm(queries=queries),
+            user=user,
+        )
 
-    for searchQuery, results in zip(queries, gathered_results):
-        try:
-            if isinstance(results, Exception):
-                raise Exception(f"Error searching {searchQuery}: {str(results)}")
+        if results:
+            files = form_data.get("files", [])
 
-            if results:
-                all_results.append(results)
-                files = form_data.get("files", [])
+            if results.get("collection_names"):
+                for col_idx, collection_name in enumerate(
+                    results.get("collection_names")
+                ):
+                    files.append(
+                        {
+                            "collection_name": collection_name,
+                            "name": ", ".join(queries),
+                            "type": "web_search",
+                            "urls": results["filenames"],
+                        }
+                    )
+            elif results.get("docs"):
+                # Invoked when bypass embedding and retrieval is set to True
+                docs = results["docs"]
+                files.append(
+                    {
+                        "docs": docs,
+                        "name": ", ".join(queries),
+                        "type": "web_search",
+                        "urls": results["filenames"],
+                    }
+                )
 
-                if results.get("collection_names"):
-                    for col_idx, collection_name in enumerate(
-                        results.get("collection_names")
-                    ):
-                        files.append(
-                            {
-                                "collection_name": collection_name,
-                                "name": searchQuery,
-                                "type": "web_search",
-                                "urls": [results["filenames"][col_idx]],
-                            }
-                        )
-                elif results.get("docs"):
-                    # Invoked when bypass embedding and retrieval is set to True
-                    docs = results["docs"]
-
-                    if len(docs) == len(results["filenames"]):
-                        # the number of docs and filenames (urls) should be the same
-                        for doc_idx, doc in enumerate(docs):
-                            files.append(
-                                {
-                                    "docs": [doc],
-                                    "name": searchQuery,
-                                    "type": "web_search",
-                                    "urls": [results["filenames"][doc_idx]],
-                                }
-                            )
-                    else:
-                        # edge case when the number of docs and filenames (urls) are not the same
-                        # this should not happen, but if it does, we will just append the docs
-                        files.append(
-                            {
-                                "docs": results.get("docs", []),
-                                "name": searchQuery,
-                                "type": "web_search",
-                                "urls": results["filenames"],
-                            }
-                        )
+            form_data["files"] = files
 
-                form_data["files"] = files
-        except Exception as e:
-            log.exception(e)
             await event_emitter(
                 {
                     "type": "status",
                     "data": {
                         "action": "web_search",
-                        "description": 'Error searching "{{searchQuery}}"',
-                        "query": searchQuery,
+                        "description": "Searched {{count}} sites",
+                        "urls": results["filenames"],
+                        "done": True,
+                    },
+                }
+            )
+        else:
+            await event_emitter(
+                {
+                    "type": "status",
+                    "data": {
+                        "action": "web_search",
+                        "description": "No search results found",
                         "done": True,
                         "error": True,
                     },
                 }
             )
 
-    if all_results:
-        urls = []
-        for results in all_results:
-            if "filenames" in results:
-                urls.extend(results["filenames"])
-
-        await event_emitter(
-            {
-                "type": "status",
-                "data": {
-                    "action": "web_search",
-                    "description": "Searched {{count}} sites",
-                    "urls": urls,
-                    "done": True,
-                },
-            }
-        )
-    else:
+    except Exception as e:
+        log.exception(e)
         await event_emitter(
             {
                 "type": "status",
                 "data": {
                     "action": "web_search",
-                    "description": "No search results found",
+                    "description": "An error occurred while searching the web",
+                    "queries": queries,
                     "done": True,
                     "error": True,
                 },