|
@@ -16,6 +16,8 @@ from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
|
|
|
from open_webui.models.users import UserModel
|
|
|
from open_webui.models.files import Files
|
|
|
|
|
|
+from open_webui.retrieval.vector.main import GetResult
|
|
|
+
|
|
|
from open_webui.env import (
|
|
|
SRC_LOG_LEVELS,
|
|
|
OFFLINE_MODE,
|
|
@@ -98,7 +100,7 @@ def get_doc(collection_name: str, user: UserModel = None):
|
|
|
|
|
|
def query_doc_with_hybrid_search(
|
|
|
collection_name: str,
|
|
|
- collection_data,
|
|
|
+ collection_result: GetResult,
|
|
|
query: str,
|
|
|
embedding_function,
|
|
|
k: int,
|
|
@@ -108,8 +110,8 @@ def query_doc_with_hybrid_search(
|
|
|
) -> dict:
|
|
|
try:
|
|
|
bm25_retriever = BM25Retriever.from_texts(
|
|
|
- texts=collection_data.documents[0],
|
|
|
- metadatas=collection_data.metadatas[0],
|
|
|
+ texts=collection_result.documents[0],
|
|
|
+ metadatas=collection_result.metadatas[0],
|
|
|
)
|
|
|
bm25_retriever.k = k
|
|
|
|
|
@@ -135,9 +137,9 @@ def query_doc_with_hybrid_search(
|
|
|
|
|
|
result = compression_retriever.invoke(query)
|
|
|
|
|
|
- distances = [d.metadata.get("score") for d in collection_data]
|
|
|
- documents = [d.page_content for d in collection_data]
|
|
|
- metadatas = [d.metadata for d in collection_data]
|
|
|
+ distances = [d.metadata.get("score") for d in result]
|
|
|
+ documents = [d.page_content for d in result]
|
|
|
+ metadatas = [d.metadata for d in result]
|
|
|
|
|
|
# retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker
|
|
|
if k < k_reranker:
|
|
@@ -146,7 +148,8 @@ def query_doc_with_hybrid_search(
|
|
|
)
|
|
|
sorted_items = sorted_items[:k]
|
|
|
distances, documents, metadatas = map(list, zip(*sorted_items))
|
|
|
- collection_data = {
|
|
|
+
|
|
|
+ result = {
|
|
|
"distances": [distances],
|
|
|
"documents": [documents],
|
|
|
"metadatas": [metadatas],
|
|
@@ -154,9 +157,9 @@ def query_doc_with_hybrid_search(
|
|
|
|
|
|
log.info(
|
|
|
"query_doc_with_hybrid_search:result "
|
|
|
- + f'{collection_data["metadatas"]} {collection_data["distances"]}'
|
|
|
+ + f'{result["metadatas"]} {result["distances"]}'
|
|
|
)
|
|
|
- return collection_data
|
|
|
+ return result
|
|
|
except Exception as e:
|
|
|
raise e
|
|
|
|
|
@@ -279,20 +282,22 @@ def query_collection_with_hybrid_search(
|
|
|
error = False
|
|
|
# Fetch collection data once per collection sequentially
|
|
|
# Avoid fetching the same data multiple times later
|
|
|
- collection_data = {}
|
|
|
+ collection_results = {}
|
|
|
for collection_name in collection_names:
|
|
|
try:
|
|
|
- collection_data[collection_name] = VECTOR_DB_CLIENT.get(collection_name=collection_name)
|
|
|
+ collection_results[collection_name] = VECTOR_DB_CLIENT.get(
|
|
|
+ collection_name=collection_name
|
|
|
+ )
|
|
|
except Exception as e:
|
|
|
log.exception(f"Failed to fetch collection {collection_name}: {e}")
|
|
|
- collection_data[collection_name] = None
|
|
|
+ collection_results[collection_name] = None
|
|
|
|
|
|
for collection_name in collection_names:
|
|
|
try:
|
|
|
for query in queries:
|
|
|
result = query_doc_with_hybrid_search(
|
|
|
collection_name=collection_name,
|
|
|
- collection_data=collection_data[collection_name],
|
|
|
+ collection_result=collection_results[collection_name],
|
|
|
query=query,
|
|
|
embedding_function=embedding_function,
|
|
|
k=k,
|