enh/refac: url input handling

Timothy Jaeryang Baek committed 1 day ago
parent commit a2a2bafdf6

+ 7 - 0
backend/open_webui/retrieval/loaders/youtube.py

@@ -157,3 +157,10 @@ class YoutubeLoader:
             f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
         )
         raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))
+
+    async def aload(self) -> list[Document]:
+        """Asynchronously load YouTube transcripts into `Document` objects."""
+        import asyncio
+
+        loop = asyncio.get_running_loop()
+        return await loop.run_in_executor(None, self.load)
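
The new coroutine simply offloads the blocking transcript fetch to the default thread-pool executor. A minimal sketch of driving it from async code, assuming the constructor accepts a full watch URL and a list of language codes (the new get_loader() helper passes exactly that); the URL below is a placeholder:

    import asyncio
    from open_webui.retrieval.loaders.youtube import YoutubeLoader

    async def main():
        # Placeholder video URL and language list, mirroring what
        # get_loader() passes from the app config.
        loader = YoutubeLoader(
            "https://www.youtube.com/watch?v=VIDEO_ID", language=["en"]
        )
        docs = await loader.aload()  # load() runs in a worker thread
        print(len(docs))

    asyncio.run(main())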

+ 39 - 1
backend/open_webui/retrieval/utils.py

@@ -6,6 +6,7 @@ import requests
 import hashlib
 from concurrent.futures import ThreadPoolExecutor
 import time
+import re
 
 from urllib.parse import quote
 from huggingface_hub import snapshot_download
@@ -16,6 +17,7 @@ from langchain_core.documents import Document
 from open_webui.config import VECTOR_DB
 from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
 
+
 from open_webui.models.users import UserModel
 from open_webui.models.files import Files
 from open_webui.models.knowledge import Knowledges
@@ -27,6 +29,9 @@ from open_webui.retrieval.vector.main import GetResult
 from open_webui.utils.access_control import has_access
 from open_webui.utils.misc import get_message_list
 
+from open_webui.retrieval.web.utils import get_web_loader
+from open_webui.retrieval.loaders.youtube import YoutubeLoader
+
 
 from open_webui.env import (
     SRC_LOG_LEVELS,
@@ -49,6 +54,33 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun
 from langchain_core.retrievers import BaseRetriever
 
 
+def is_youtube_url(url: str) -> bool:
+    youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$"
+    return re.match(youtube_regex, url) is not None
+
+
+def get_loader(request, url: str):
+    if is_youtube_url(url):
+        return YoutubeLoader(
+            url,
+            language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
+            proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
+        )
+    else:
+        return get_web_loader(
+            url,
+            verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
+            requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
+        )
+
+
+def get_content_from_url(request, url: str) -> tuple[str, list[Document]]:
+    loader = get_loader(request, url)
+    docs = loader.load()
+    content = " ".join([doc.page_content for doc in docs])
+    return content, docs
+
+
 class VectorSearchRetriever(BaseRetriever):
     collection_name: Any
     embedding_function: Any
@@ -571,6 +603,13 @@ def get_sources_from_items(
                         "metadatas": [[{"file_id": chat.id, "name": chat.title}]],
                     }
 
+        elif item.get("type") == "url":
+            content, docs = get_content_from_url(request, item.get("url"))
+            if docs:
+                query_result = {
+                    "documents": [[content]],
+                    "metadatas": [[{"url": item.get("url"), "name": item.get("url")}]],
+                }
         elif item.get("type") == "file":
             if (
                 item.get("context") == "full"
@@ -736,7 +775,6 @@ def get_sources_from_items(
                     sources.append(source)
         except Exception as e:
             log.exception(e)
-
     return sources
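
For illustration (not part of the commit), the relocated helpers behave as follows: is_youtube_url() matches youtube.com and youtu.be hosts with an optional scheme and www. prefix, and everything else falls through to the generic web loader.

    from open_webui.retrieval.utils import is_youtube_url

    assert is_youtube_url("https://youtu.be/abc123")
    assert is_youtube_url("www.youtube.com/watch?v=abc123")  # scheme is optional
    assert not is_youtube_url("https://example.com/watch")   # host must match

Note that other subdomains (e.g. music.youtube.com) do not match the pattern, so those URLs are handled by the web loader rather than the transcript loader.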
 
 

+ 6 - 28
backend/open_webui/routers/retrieval.py

@@ -71,6 +71,7 @@ from open_webui.retrieval.web.firecrawl import search_firecrawl
 from open_webui.retrieval.web.external import search_external
 
 from open_webui.retrieval.utils import (
+    get_content_from_url,
     get_embedding_function,
     get_reranking_function,
     get_model_path,
@@ -1691,33 +1692,6 @@ def process_text(
         )
 
 
-def is_youtube_url(url: str) -> bool:
-    youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$"
-    return re.match(youtube_regex, url) is not None
-
-
-def get_loader(request, url: str):
-    if is_youtube_url(url):
-        return YoutubeLoader(
-            url,
-            language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
-            proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
-        )
-    else:
-        return get_web_loader(
-            url,
-            verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
-            requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
-        )
-
-
-def get_content_from_url(request, url: str) -> str:
-    loader = get_loader(request, url)
-    docs = loader.load()
-    content = " ".join([doc.page_content for doc in docs])
-    return content, docs
-
-
 @router.post("/process/youtube")
 @router.post("/process/web")
 def process_web(
@@ -1733,7 +1707,11 @@ def process_web(
 
         if not request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
             save_docs_to_vector_db(
-                request, docs, collection_name, overwrite=True, user=user
+                request,
+                docs,
+                collection_name,
+                overwrite=True,
+                user=user,
             )
         else:
             collection_name = None
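
For reference, a hypothetical client call to the consolidated endpoint; the path assumes the retrieval router's default /api/v1/retrieval mount point and a placeholder API key, adjust both to your deployment:

    import requests

    # Illustrative only: host, port, and key are placeholders.
    resp = requests.post(
        "http://localhost:8080/api/v1/retrieval/process/web",
        headers={"Authorization": "Bearer YOUR_API_KEY"},
        json={"url": "https://example.com/article"},
    )
    print(resp.json())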

+ 22 - 12
backend/open_webui/utils/middleware.py

@@ -40,7 +40,10 @@ from open_webui.routers.tasks import (
     generate_image_prompt,
     generate_chat_tags,
 )
-from open_webui.routers.retrieval import process_web_search, SearchForm
+from open_webui.routers.retrieval import (
+    process_web_search,
+    SearchForm,
+)
 from open_webui.routers.images import (
     load_b64_image_data,
     image_generations,
@@ -76,6 +79,7 @@ from open_webui.utils.task import (
 )
 from open_webui.utils.misc import (
     deep_update,
+    extract_urls,
     get_message_list,
     add_or_update_system_message,
     add_or_update_user_message,
@@ -823,7 +827,11 @@ async def chat_completion_files_handler(
 
     if files := body.get("metadata", {}).get("files", None):
         # Check if all files are in full context mode
-        all_full_context = all(item.get("context") == "full" for item in files)
+        all_full_context = all(
+            item.get("context") == "full"
+            for item in files
+            if item.get("type") == "file"
+        )
 
         queries = []
         if not all_full_context:
@@ -855,10 +863,6 @@ async def chat_completion_files_handler(
             except:
                 pass
 
-        if len(queries) == 0:
-            queries = [get_last_user_message(body["messages"])]
-
-        if not all_full_context:
             await __event_emitter__(
                 {
                     "type": "status",
@@ -870,6 +874,9 @@ async def chat_completion_files_handler(
                 }
             )
 
+        if len(queries) == 0:
+            queries = [get_last_user_message(body["messages"])]
+
         try:
             # Offload get_sources_from_items to a separate thread
             loop = asyncio.get_running_loop()
@@ -908,7 +915,6 @@ async def chat_completion_files_handler(
         log.debug(f"rag_contexts:sources: {sources}")
 
         unique_ids = set()
-
         for source in sources or []:
             if not source or len(source.keys()) == 0:
                 continue
@@ -927,7 +933,6 @@ async def chat_completion_files_handler(
                 unique_ids.add(_id)
 
         sources_count = len(unique_ids)
-
         await __event_emitter__(
             {
                 "type": "status",
@@ -1170,8 +1175,15 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     tool_ids = form_data.pop("tool_ids", None)
     files = form_data.pop("files", None)
 
-    # Remove files duplicates
-    if files:
+    prompt = get_last_user_message(form_data["messages"])
+    urls = extract_urls(prompt) if prompt else []
+
+    if files or urls:
+        if not files:
+            files = []
+        files = [*files, *[{"type": "url", "url": url, "name": url} for url in urls]]
+
+        # Remove duplicate files based on their content
         files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
 
     metadata = {
@@ -1372,8 +1384,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                     )
 
         context_string = context_string.strip()
-
-        prompt = get_last_user_message(form_data["messages"])
         if prompt is None:
             raise Exception("No user message found")
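
The dedup line above keys each file dict by its sorted JSON serialization, so URL items injected from the prompt collapse with any duplicates the client already attached. A small standalone illustration:

    import json

    files = [
        {"type": "url", "url": "https://example.com", "name": "https://example.com"},
        {"name": "https://example.com", "url": "https://example.com", "type": "url"},
    ]
    # sort_keys=True makes key order irrelevant, so both dicts
    # serialize identically and only one entry survives.
    files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
    print(len(files))  # 1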
 

+ 8 - 0
backend/open_webui/utils/misc.py

@@ -531,3 +531,11 @@ def throttle(interval: float = 10.0):
         return wrapper
 
     return decorator
+
+
+def extract_urls(text: str) -> list[str]:
+    # Match http:// and https:// URLs, stopping at the first whitespace.
+    url_pattern = re.compile(
+        r"(https?://[^\s]+)", re.IGNORECASE
+    )
+    return url_pattern.findall(text)
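
A quick illustration of the new helper, including one caveat worth knowing:

    from open_webui.utils.misc import extract_urls

    text = "see https://youtu.be/abc123 and https://example.com/page, thanks"
    print(extract_urls(text))
    # ['https://youtu.be/abc123', 'https://example.com/page,']
    # The pattern stops only at whitespace, so trailing punctuation
    # (the comma above) is captured as part of the URL.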