@@ -3,6 +3,8 @@ import logging
 import mimetypes
 import os
 import shutil
+import asyncio
+
 import uuid
 from datetime import datetime
@@ -188,7 +190,7 @@ class ProcessUrlForm(CollectionNameForm):
 
 
 class SearchForm(BaseModel):
-    query: str
+    queries: List[str]
 
 
 @router.get("/")
@@ -1568,16 +1570,34 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
 async def process_web_search(
     request: Request, form_data: SearchForm, user=Depends(get_verified_user)
 ):
+
+    urls = []
     try:
         logging.info(
-            f"trying to web search with {request.app.state.config.WEB_SEARCH_ENGINE, form_data.query}"
+            f"trying to web search with {request.app.state.config.WEB_SEARCH_ENGINE, form_data.queries}"
         )
-        web_results = await run_in_threadpool(
-            search_web,
-            request,
-            request.app.state.config.WEB_SEARCH_ENGINE,
-            form_data.query,
-        )
+
+        search_tasks = [
+            run_in_threadpool(
+                search_web,
+                request,
+                request.app.state.config.WEB_SEARCH_ENGINE,
+                query,
+            )
+            for query in form_data.queries
+        ]
+
+        search_results = await asyncio.gather(*search_tasks)
+
+        for result in search_results:
+            if result:
+                for item in result:
+                    if item and item.link:
+                        urls.append(item.link)
+
+        urls = list(dict.fromkeys(urls))
+        log.debug(f"urls: {urls}")
+
     except Exception as e:
         log.exception(e)
 
@@ -1586,15 +1606,7 @@ async def process_web_search(
             detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
         )
 
-    log.debug(f"web_results: {web_results}")
-
     try:
-        urls = [result.link for result in web_results]
-
-        # Remove duplicates
-        urls = list(dict.fromkeys(urls))
-        log.debug(f"urls: {urls}")
-
         loader = get_web_loader(
             urls,
             verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
@@ -1604,7 +1616,7 @@ async def process_web_search(
         docs = await loader.aload()
         urls = [
            doc.metadata.get("source") for doc in docs if doc.metadata.get("source")
-        ]  # only keep URLs
+        ]  # only keep the urls returned by the loader
 
         if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
             return {
@@ -1621,29 +1633,28 @@ async def process_web_search(
                 "loaded_count": len(docs),
             }
         else:
-            collection_names = []
-            for doc_idx, doc in enumerate(docs):
-                if doc and doc.page_content:
-                    try:
-                        collection_name = f"web-search-{calculate_sha256_string(form_data.query + '-' + urls[doc_idx])}"[
-                            :63
-                        ]
-
-                        collection_names.append(collection_name)
-                        await run_in_threadpool(
-                            save_docs_to_vector_db,
-                            request,
-                            [doc],
-                            collection_name,
-                            overwrite=True,
-                            user=user,
-                        )
-                    except Exception as e:
-                        log.debug(f"error saving doc {doc_idx}: {e}")
+            # Create a single collection for all documents
+            collection_name = (
+                f"web-search-{calculate_sha256_string('-'.join(form_data.queries))}"[
+                    :63
+                ]
+            )
+
+            try:
+                await run_in_threadpool(
+                    save_docs_to_vector_db,
+                    request,
+                    docs,
+                    collection_name,
+                    overwrite=True,
+                    user=user,
+                )
+            except Exception as e:
+                log.debug(f"error saving docs: {e}")
 
             return {
                 "status": True,
-                "collection_names": collection_names,
+                "collection_names": [collection_name],
                 "filenames": urls,
                 "loaded_count": len(docs),
             }