utils.py

import hashlib
import logging
import operator
import os
from typing import Any, Optional, Sequence, Union

import requests
from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.retrievers import BaseRetriever

from open_webui.config import (
    RAG_EMBEDDING_CONTENT_PREFIX,
    RAG_EMBEDDING_PREFIX_FIELD_NAME,
    RAG_EMBEDDING_QUERY_PREFIX,
    VECTOR_DB,
)
from open_webui.env import (
    ENABLE_FORWARD_USER_INFO_HEADERS,
    OFFLINE_MODE,
    SRC_LOG_LEVELS,
)
from open_webui.models.files import Files
from open_webui.models.users import UserModel
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.retrieval.vector.main import GetResult

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class VectorSearchRetriever(BaseRetriever):
    collection_name: Any
    embedding_function: Any
    top_k: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        result = VECTOR_DB_CLIENT.search(
            collection_name=self.collection_name,
            vectors=[self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)],
            limit=self.top_k,
        )

        ids = result.ids[0]
        metadatas = result.metadatas[0]
        documents = result.documents[0]

        results = []
        for idx in range(len(ids)):
            results.append(
                Document(
                    metadata=metadatas[idx],
                    page_content=documents[idx],
                )
            )
        return results
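

# Usage sketch (illustrative only; the collection name and embedding callable
# below are assumptions, not values defined in this module):
#
#   retriever = VectorSearchRetriever(
#       collection_name="file-abc123",
#       embedding_function=embedding_fn,  # callable: (text, prefix) -> list[float]
#       top_k=5,
#   )
#   docs = retriever.invoke("How does hybrid search work?")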


def query_doc(
    collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
):
    try:
        result = VECTOR_DB_CLIENT.search(
            collection_name=collection_name,
            vectors=[query_embedding],
            limit=k,
        )
        if result:
            log.info(f"query_doc:result {result.ids} {result.metadatas}")
        return result
    except Exception as e:
        log.exception(f"Error querying doc {collection_name} with limit {k}: {e}")
        raise e
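

# Usage sketch (assumes an embedding callable as returned by
# get_embedding_function; the collection name is hypothetical):
#
#   embedding = embedding_fn("my question", prefix=RAG_EMBEDDING_QUERY_PREFIX)
#   result = query_doc("file-abc123", query_embedding=embedding, k=4)
#   # result.documents[0] holds the k closest chunks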


def get_doc(collection_name: str, user: UserModel = None):
    try:
        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
        if result:
            log.info(f"get_doc:result {result.ids} {result.metadatas}")
        return result
    except Exception as e:
        log.exception(f"Error getting doc {collection_name}: {e}")
        raise e


def query_doc_with_hybrid_search(
    collection_name: str,
    collection_result: GetResult,
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
) -> dict:
    try:
        bm25_retriever = BM25Retriever.from_texts(
            texts=collection_result.documents[0],
            metadatas=collection_result.metadatas[0],
        )
        bm25_retriever.k = k

        vector_search_retriever = VectorSearchRetriever(
            collection_name=collection_name,
            embedding_function=embedding_function,
            top_k=k,
        )

        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
        )
        compressor = RerankCompressor(
            embedding_function=embedding_function,
            top_n=k_reranker,
            reranking_function=reranking_function,
            r_score=r,
        )

        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=ensemble_retriever
        )

        result = compression_retriever.invoke(query)

        distances = [d.metadata.get("score") for d in result]
        documents = [d.page_content for d in result]
        metadatas = [d.metadata for d in result]

        # The reranker returns up to k_reranker items; if k < k_reranker,
        # sort by score and keep only the top k. Note the unpack order must
        # match the zip order (distance, metadata, document).
        if k < k_reranker:
            sorted_items = sorted(
                zip(distances, metadatas, documents), key=lambda x: x[0], reverse=True
            )[:k]
            if sorted_items:
                distances, metadatas, documents = map(list, zip(*sorted_items))

        result = {
            "distances": [distances],
            "documents": [documents],
            "metadatas": [metadatas],
        }

        log.info(
            "query_doc_with_hybrid_search:result "
            + f'{result["metadatas"]} {result["distances"]}'
        )
        return result
    except Exception as e:
        log.exception(
            f"Error querying doc {collection_name} with hybrid search: {e}"
        )
        raise e
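

# Usage sketch: the reranker below is assumed to expose a .predict() method
# over (query, passage) pairs, e.g. a sentence-transformers CrossEncoder.
#
#   collection = VECTOR_DB_CLIENT.get(collection_name="file-abc123")
#   result = query_doc_with_hybrid_search(
#       collection_name="file-abc123",
#       collection_result=collection,
#       query="what changed in v2?",
#       embedding_function=embedding_fn,
#       k=4,
#       reranking_function=reranker,
#       k_reranker=10,
#       r=0.0,
#   )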


def merge_get_results(get_results: list[dict]) -> dict:
    # Initialize lists to store combined data
    combined_documents = []
    combined_metadatas = []
    combined_ids = []

    for data in get_results:
        combined_documents.extend(data["documents"][0])
        combined_metadatas.extend(data["metadatas"][0])
        combined_ids.extend(data["ids"][0])

    # Create the output dictionary
    result = {
        "documents": [combined_documents],
        "metadatas": [combined_metadatas],
        "ids": [combined_ids],
    }
    return result


def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict:
    # Store documents keyed by a hash of their content so duplicates collapse
    combined = dict()

    for data in query_results:
        distances = data["distances"][0]
        documents = data["documents"][0]
        metadatas = data["metadatas"][0]

        for distance, document, metadata in zip(distances, documents, metadatas):
            if isinstance(document, str):
                doc_hash = hashlib.md5(
                    document.encode()
                ).hexdigest()  # Compute a hash for uniqueness

                if doc_hash not in combined:
                    combined[doc_hash] = (distance, document, metadata)
                    continue  # if the document is new, no further comparison is needed

                # if the document is already in, but the new distance is better, update
                if distance > combined[doc_hash][0]:
                    combined[doc_hash] = (distance, document, metadata)

    combined = list(combined.values())
    # Sort the list based on distances (higher scores first)
    combined.sort(key=lambda x: x[0], reverse=True)

    # Slice to keep only the top k elements
    sorted_distances, sorted_documents, sorted_metadatas = (
        zip(*combined[:k]) if combined else ([], [], [])
    )

    # Create and return the output dictionary
    return {
        "distances": [list(sorted_distances)],
        "documents": [list(sorted_documents)],
        "metadatas": [list(sorted_metadatas)],
    }
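

# Worked example: duplicates collapse to their best score, then results are
# sorted by score (descending) and cut to k.
#
#   a = {"distances": [[0.9, 0.4]], "documents": [["A", "B"]], "metadatas": [[{}, {}]]}
#   b = {"distances": [[0.7]], "documents": [["A"]], "metadatas": [[{}]]}
#   merge_and_sort_query_results([a, b], k=2)
#   # -> {"distances": [[0.9, 0.4]], "documents": [["A", "B"]], "metadatas": [[{}, {}]]}
#   # the duplicate "A" keeps its higher score, 0.9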


def get_all_items_from_collections(collection_names: list[str]) -> dict:
    results = []

    for collection_name in collection_names:
        if collection_name:
            try:
                result = get_doc(collection_name=collection_name)
                if result is not None:
                    results.append(result.model_dump())
            except Exception as e:
                log.exception(f"Error when querying the collection: {e}")

    return merge_get_results(results)


def query_collection(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
) -> dict:
    results = []

    for query in queries:
        query_embedding = embedding_function(query, prefix=RAG_EMBEDDING_QUERY_PREFIX)
        for collection_name in collection_names:
            if collection_name:
                try:
                    result = query_doc(
                        collection_name=collection_name,
                        k=k,
                        query_embedding=query_embedding,
                    )
                    if result is not None:
                        results.append(result.model_dump())
                except Exception as e:
                    log.exception(f"Error when querying the collection: {e}")

    return merge_and_sort_query_results(results, k=k)
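

# Usage sketch (hypothetical collection names; embedding_fn as returned by
# get_embedding_function):
#
#   merged = query_collection(
#       collection_names=["file-abc123", "file-def456"],
#       queries=["question one", "question two"],
#       embedding_function=embedding_fn,
#       k=4,
#   )
#   # merged["documents"][0] holds the overall top-k chunks across collections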


def query_collection_with_hybrid_search(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
) -> dict:
    results = []
    error = False

    # Fetch collection data once per collection, sequentially, to avoid
    # fetching the same data multiple times later.
    collection_results = {}
    for collection_name in collection_names:
        try:
            collection_results[collection_name] = VECTOR_DB_CLIENT.get(
                collection_name=collection_name
            )
        except Exception as e:
            log.exception(f"Failed to fetch collection {collection_name}: {e}")
            collection_results[collection_name] = None

    for collection_name in collection_names:
        try:
            for query in queries:
                result = query_doc_with_hybrid_search(
                    collection_name=collection_name,
                    collection_result=collection_results[collection_name],
                    query=query,
                    embedding_function=embedding_function,
                    k=k,
                    reranking_function=reranking_function,
                    k_reranker=k_reranker,
                    r=r,
                )
                results.append(result)
        except Exception as e:
            log.exception(f"Error when querying the collection with hybrid search: {e}")
            error = True

    if error:
        raise Exception(
            "Hybrid search failed for one or more collections; "
            "falling back to non-hybrid search."
        )

    return merge_and_sort_query_results(results, k=k)


def get_embedding_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    url,
    key,
    embedding_batch_size,
):
    if embedding_engine == "":
        # Local sentence-transformers model: encode directly.
        return lambda query, prefix=None, user=None: embedding_function.encode(
            query, prompt=prefix if prefix else None
        ).tolist()
    elif embedding_engine in ["ollama", "openai"]:
        func = lambda query, prefix=None, user=None: generate_embeddings(
            engine=embedding_engine,
            model=embedding_model,
            text=query,
            prefix=prefix,
            url=url,
            key=key,
            user=user,
        )

        def generate_multiple(query, prefix, user, func):
            # Split list inputs into chunks of embedding_batch_size per request.
            if isinstance(query, list):
                embeddings = []
                for i in range(0, len(query), embedding_batch_size):
                    embeddings.extend(
                        func(
                            query[i : i + embedding_batch_size],
                            prefix=prefix,
                            user=user,
                        )
                    )
                return embeddings
            else:
                return func(query, prefix, user)

        return lambda query, prefix=None, user=None: generate_multiple(
            query, prefix, user, func
        )
    else:
        raise ValueError(f"Unknown embedding engine: {embedding_engine}")
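

# Usage sketch (the key and model name are placeholders): the returned callable
# accepts either a single string or a list, batching lists transparently.
#
#   embedding_fn = get_embedding_function(
#       embedding_engine="openai",
#       embedding_model="text-embedding-3-small",
#       embedding_function=None,  # only used when embedding_engine == ""
#       url="https://api.openai.com/v1",
#       key="sk-...",
#       embedding_batch_size=32,
#   )
#   one = embedding_fn("hello", prefix=RAG_EMBEDDING_QUERY_PREFIX)
#   many = embedding_fn(["chunk 1", "chunk 2"])  # sent in batches of 32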


def get_sources_from_files(
    request,
    files,
    queries,
    embedding_function,
    k,
    reranking_function,
    k_reranker,
    r,
    hybrid_search,
    full_context=False,
):
    log.debug(
        f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
    )

    extracted_collections = []
    relevant_contexts = []

    for file in files:
        context = None
        if file.get("docs"):
            # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
            context = {
                "documents": [[doc.get("content") for doc in file.get("docs")]],
                "metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
            }
        elif file.get("context") == "full":
            # Manual full-context mode toggle
            context = {
                "documents": [[file.get("file").get("data", {}).get("content")]],
                "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
            }
        elif (
            file.get("type") != "web_search"
            and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
        ):
            # BYPASS_EMBEDDING_AND_RETRIEVAL
            if file.get("type") == "collection":
                file_ids = file.get("data", {}).get("file_ids", [])

                documents = []
                metadatas = []
                for file_id in file_ids:
                    file_object = Files.get_file_by_id(file_id)

                    if file_object:
                        documents.append(file_object.data.get("content", ""))
                        metadatas.append(
                            {
                                "file_id": file_id,
                                "name": file_object.filename,
                                "source": file_object.filename,
                            }
                        )

                context = {
                    "documents": [documents],
                    "metadatas": [metadatas],
                }
            elif file.get("id"):
                file_object = Files.get_file_by_id(file.get("id"))
                if file_object:
                    context = {
                        "documents": [[file_object.data.get("content", "")]],
                        "metadatas": [
                            [
                                {
                                    "file_id": file.get("id"),
                                    "name": file_object.filename,
                                    "source": file_object.filename,
                                }
                            ]
                        ],
                    }
            elif file.get("file", {}).get("data"):
                context = {
                    "documents": [[file.get("file").get("data", {}).get("content")]],
                    "metadatas": [
                        [file.get("file").get("data", {}).get("metadata", {})]
                    ],
                }
        else:
            collection_names = []
            if file.get("type") == "collection":
                if file.get("legacy"):
                    collection_names = file.get("collection_names", [])
                else:
                    collection_names.append(file["id"])
            elif file.get("collection_name"):
                collection_names.append(file["collection_name"])
            elif file.get("id"):
                if file.get("legacy"):
                    collection_names.append(f"{file['id']}")
                else:
                    collection_names.append(f"file-{file['id']}")

            collection_names = set(collection_names).difference(extracted_collections)
            if not collection_names:
                log.debug(f"skipping {file} as it has already been extracted")
                continue

            if full_context:
                try:
                    context = get_all_items_from_collections(collection_names)
                except Exception as e:
                    log.exception(e)
            else:
                try:
                    context = None
                    if file.get("type") == "text":
                        context = file["content"]
                    else:
                        if hybrid_search:
                            try:
                                context = query_collection_with_hybrid_search(
                                    collection_names=collection_names,
                                    queries=queries,
                                    embedding_function=embedding_function,
                                    k=k,
                                    reranking_function=reranking_function,
                                    k_reranker=k_reranker,
                                    r=r,
                                )
                            except Exception as e:
                                log.debug(
                                    "Error when using hybrid search, falling back "
                                    f"to non-hybrid search: {e}"
                                )

                        if (not hybrid_search) or (context is None):
                            context = query_collection(
                                collection_names=collection_names,
                                queries=queries,
                                embedding_function=embedding_function,
                                k=k,
                            )
                except Exception as e:
                    log.exception(e)

            extracted_collections.extend(collection_names)

        if context:
            if "data" in file:
                del file["data"]

            relevant_contexts.append({**context, "file": file})

    sources = []
    for context in relevant_contexts:
        try:
            if "documents" in context:
                if "metadatas" in context:
                    source = {
                        "source": context["file"],
                        "document": context["documents"][0],
                        "metadata": context["metadatas"][0],
                    }
                    if "distances" in context and context["distances"]:
                        source["distances"] = context["distances"][0]

                    sources.append(source)
        except Exception as e:
            log.exception(e)

    return sources
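

# Input sketch: each entry in `files` is a dict, and the branch taken above
# depends on its shape. A few illustrative (hypothetical) examples, assuming
# retrieval is enabled (BYPASS_EMBEDDING_AND_RETRIEVAL is off):
#
#   {"type": "collection", "id": "kb-1"}              -> queries collection "kb-1"
#   {"id": "abc", "legacy": False}                    -> queries collection "file-abc"
#   {"context": "full", "file": {"data": {...}}}      -> bypasses retrieval entirely
#   {"docs": [{"content": "...", "metadata": {...}}]} -> pre-retrieved (web) results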


def get_model_path(model: str, update_model: bool = False):
    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")

    local_files_only = not update_model
    if OFFLINE_MODE:
        local_files_only = True

    snapshot_kwargs = {
        "cache_dir": cache_dir,
        "local_files_only": local_files_only,
    }

    log.debug(f"model: {model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # Inspiration from upstream sentence_transformers
    if (
        os.path.exists(model)
        or ("\\" in model or model.count("/") > 1)
        and local_files_only
    ):
        # If fully qualified path exists, return input, else set repo_id
        return model
    elif "/" not in model:
        # Set valid repo_id for model short-name
        model = "sentence-transformers" + "/" + model

    snapshot_kwargs["repo_id"] = model

    # Attempt to query the huggingface_hub library to determine the local path and/or to update
    try:
        model_repo_path = snapshot_download(**snapshot_kwargs)
        log.debug(f"model_repo_path: {model_repo_path}")
        return model_repo_path
    except Exception as e:
        log.exception(f"Cannot determine model snapshot path: {e}")
        return model
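

# Behaviour sketch (hypothetical paths; assumes SENTENCE_TRANSFORMERS_HOME unset):
#
#   get_model_path("all-MiniLM-L6-v2")
#   # -> local snapshot path for "sentence-transformers/all-MiniLM-L6-v2"
#   get_model_path("/models/my-local-model")  # an existing path is returned as-is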


def generate_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str = "https://api.openai.com/v1",
    key: str = "",
    prefix: Optional[str] = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    try:
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        r = requests.post(
            f"{url}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
                **(
                    {
                        "X-OpenWebUI-User-Name": user.name,
                        "X-OpenWebUI-User-Id": user.id,
                        "X-OpenWebUI-User-Email": user.email,
                        "X-OpenWebUI-User-Role": user.role,
                    }
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
                    else {}
                ),
            },
            json=json_data,
        )
        r.raise_for_status()
        data = r.json()
        if "data" in data:
            return [elem["embedding"] for elem in data["data"]]
        else:
            # Raising a bare string is invalid in Python 3; raise an Exception.
            raise Exception("Something went wrong :/")
    except Exception as e:
        log.exception(f"Error generating openai batch embeddings: {e}")
        return None
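

# Usage sketch (placeholder API key):
#
#   vectors = generate_openai_batch_embeddings(
#       model="text-embedding-3-small",
#       texts=["first chunk", "second chunk"],
#       key="sk-...",
#   )
#   # -> [[...], [...]] on success, None on any failure (errors are logged)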


def generate_ollama_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = "",
    prefix: Optional[str] = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    try:
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        r = requests.post(
            f"{url}/api/embed",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
                **(
                    {
                        "X-OpenWebUI-User-Name": user.name,
                        "X-OpenWebUI-User-Id": user.id,
                        "X-OpenWebUI-User-Email": user.email,
                        "X-OpenWebUI-User-Role": user.role,
                    }
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
                    else {}
                ),
            },
            json=json_data,
        )
        r.raise_for_status()
        data = r.json()
        if "embeddings" in data:
            return data["embeddings"]
        else:
            # Raising a bare string is invalid in Python 3; raise an Exception.
            raise Exception("Something went wrong :/")
    except Exception as e:
        log.exception(f"Error generating ollama batch embeddings: {e}")
        return None


def generate_embeddings(
    engine: str,
    model: str,
    text: Union[str, list[str]],
    prefix: Union[str, None] = None,
    **kwargs,
):
    url = kwargs.get("url", "")
    key = kwargs.get("key", "")
    user = kwargs.get("user")

    if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
        # No dedicated prefix field: prepend the prefix to the text itself.
        if isinstance(text, list):
            text = [f"{prefix}{text_element}" for text_element in text]
        else:
            text = f"{prefix}{text}"

    if engine == "ollama":
        embeddings = generate_ollama_batch_embeddings(
            model=model,
            texts=text if isinstance(text, list) else [text],
            url=url,
            key=key,
            prefix=prefix,
            user=user,
        )
        return embeddings[0] if isinstance(text, str) else embeddings
    elif engine == "openai":
        embeddings = generate_openai_batch_embeddings(
            model, text if isinstance(text, list) else [text], url, key, prefix, user
        )
        return embeddings[0] if isinstance(text, str) else embeddings
    else:
        raise ValueError(f"Unknown embedding engine: {engine}")
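

# Usage sketch (assumes a local Ollama server with a pulled embedding model):
#
#   one = generate_embeddings(
#       "ollama", "nomic-embed-text", "a single string",
#       url="http://localhost:11434",
#   )
#   many = generate_embeddings(
#       "ollama", "nomic-embed-text", ["chunk a", "chunk b"],
#       url="http://localhost:11434",
#   )
#   # a str input returns one vector; a list returns a list of vectors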


class RerankCompressor(BaseDocumentCompressor):
    embedding_function: Any
    top_n: int
    reranking_function: Any
    r_score: float

    class Config:
        extra = "forbid"
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        reranking = self.reranking_function is not None

        if reranking:
            scores = self.reranking_function.predict(
                [(query, doc.page_content) for doc in documents]
            )
        else:
            # No reranker configured: fall back to cosine similarity between
            # the query embedding and each document embedding. The import is
            # deferred so sentence_transformers stays an optional dependency.
            from sentence_transformers import util

            query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
            document_embedding = self.embedding_function(
                [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
            )
            scores = util.cos_sim(query_embedding, document_embedding)[0]

        docs_with_scores = list(zip(documents, scores.tolist()))
        if self.r_score:
            docs_with_scores = [
                (d, s) for d, s in docs_with_scores if s >= self.r_score
            ]

        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
        final_results = []
        for doc, doc_score in result[: self.top_n]:
            metadata = doc.metadata
            metadata["score"] = doc_score
            doc = Document(
                page_content=doc.page_content,
                metadata=metadata,
            )
            final_results.append(doc)
        return final_results
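

# Usage sketch: with a cross-encoder style reranker (assumption: it exposes
# .predict() over (query, passage) pairs); r_score=0.0 disables the threshold.
#
#   compressor = RerankCompressor(
#       embedding_function=embedding_fn,
#       top_n=5,
#       reranking_function=cross_encoder,
#       r_score=0.0,
#   )
#   ranked = compressor.compress_documents(docs, "the query")
#   # each returned Document carries its relevance in metadata["score"]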