# utils.py

import logging
import os
import hashlib

from typing import Optional, Union
from concurrent.futures import ThreadPoolExecutor

import requests

from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document

from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.models.users import UserModel
from open_webui.models.files import Files
from open_webui.retrieval.vector.main import GetResult

from open_webui.env import (
    SRC_LOG_LEVELS,
    OFFLINE_MODE,
    ENABLE_FORWARD_USER_INFO_HEADERS,
)
from open_webui.config import (
    RAG_EMBEDDING_QUERY_PREFIX,
    RAG_EMBEDDING_CONTENT_PREFIX,
    RAG_EMBEDDING_PREFIX_FIELD_NAME,
    RAG_BM25_WEIGHT,
)

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

from typing import Any

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever
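
# Adapts VECTOR_DB_CLIENT.search to LangChain's BaseRetriever interface so that
# dense (vector) retrieval can participate in an EnsembleRetriever alongside
# BM25 (see query_doc_with_hybrid_search below). The query is embedded with the
# configured query prefix and the top_k hits are wrapped as LangChain Documents.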
class VectorSearchRetriever(BaseRetriever):
    collection_name: Any
    embedding_function: Any
    top_k: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        result = VECTOR_DB_CLIENT.search(
            collection_name=self.collection_name,
            vectors=[self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)],
            limit=self.top_k,
        )

        ids = result.ids[0]
        metadatas = result.metadatas[0]
        documents = result.documents[0]

        results = []
        for idx in range(len(ids)):
            results.append(
                Document(
                    metadata=metadatas[idx],
                    page_content=documents[idx],
                )
            )
        return results
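
# Dense-only lookup: searches a single collection with a precomputed query
# embedding and returns the raw vector-DB result (or None). Illustrative call,
# assuming an existing collection and an embedding of matching dimension
# ("file-abc" and embed() are placeholders, not names defined in this module):
#
#     result = query_doc("file-abc", query_embedding=embed("hello"), k=5)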
def query_doc(
    collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
):
    try:
        log.debug(f"query_doc:doc {collection_name}")
        result = VECTOR_DB_CLIENT.search(
            collection_name=collection_name,
            vectors=[query_embedding],
            limit=k,
        )

        if result:
            log.info(f"query_doc:result {result.ids} {result.metadatas}")

        return result
    except Exception as e:
        log.exception(f"Error querying doc {collection_name} with limit {k}: {e}")
        raise e

def get_doc(collection_name: str, user: UserModel = None):
    try:
        log.debug(f"get_doc:doc {collection_name}")
        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)

        if result:
            log.info(f"get_doc:result {result.ids} {result.metadatas}")

        return result
    except Exception as e:
        log.exception(f"Error getting doc {collection_name}: {e}")
        raise e
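
# Hybrid search over one collection:
#   1. BM25Retriever scores the collection's raw texts lexically.
#   2. VectorSearchRetriever scores them semantically via the vector DB.
#   3. EnsembleRetriever fuses the two result lists, weighted by
#      RAG_BM25_WEIGHT (a weight of 0 or 1 degenerates to a single retriever).
#   4. RerankCompressor (defined at the bottom of this module) rescores the
#      fused candidates and drops those below the relevance threshold r.
# The output mimics the vector DB's shape: nested lists keyed by
# "distances"/"documents"/"metadatas".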
def query_doc_with_hybrid_search(
    collection_name: str,
    collection_result: GetResult,
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
) -> dict:
    try:
        log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
        bm25_retriever = BM25Retriever.from_texts(
            texts=collection_result.documents[0],
            metadatas=collection_result.metadatas[0],
        )
        bm25_retriever.k = k

        vector_search_retriever = VectorSearchRetriever(
            collection_name=collection_name,
            embedding_function=embedding_function,
            top_k=k,
        )

        if RAG_BM25_WEIGHT <= 0:
            ensemble_retriever = EnsembleRetriever(
                retrievers=[vector_search_retriever], weights=[1.0]
            )
        elif RAG_BM25_WEIGHT >= 1:
            ensemble_retriever = EnsembleRetriever(
                retrievers=[bm25_retriever], weights=[1.0]
            )
        else:
            ensemble_retriever = EnsembleRetriever(
                retrievers=[bm25_retriever, vector_search_retriever],
                weights=[RAG_BM25_WEIGHT, 1.0 - RAG_BM25_WEIGHT],
            )

        compressor = RerankCompressor(
            embedding_function=embedding_function,
            top_n=k_reranker,
            reranking_function=reranking_function,
            r_score=r,
        )

        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=ensemble_retriever
        )

        result = compression_retriever.invoke(query)

        distances = [d.metadata.get("score") for d in result]
        documents = [d.page_content for d in result]
        metadatas = [d.metadata for d in result]

        # Keep only min(k, k_reranker) items: sort by score and cut if k < k_reranker
        if k < k_reranker and distances:
            sorted_items = sorted(
                zip(distances, documents, metadatas), key=lambda x: x[0], reverse=True
            )
            sorted_items = sorted_items[:k]
            distances, documents, metadatas = map(list, zip(*sorted_items))

        result = {
            "distances": [distances],
            "documents": [documents],
            "metadatas": [metadatas],
        }

        log.info(
            "query_doc_with_hybrid_search:result "
            + f'{result["metadatas"]} {result["distances"]}'
        )
        return result
    except Exception as e:
        log.exception(f"Error querying doc {collection_name} with hybrid search: {e}")
        raise e
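
# Concatenates several GetResult-style dicts into one. Note the nested-list
# shape: each field is a list containing a single list, matching what the
# vector DB client returns for a one-query batch.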
def merge_get_results(get_results: list[dict]) -> dict:
    # Initialize lists to store combined data
    combined_documents = []
    combined_metadatas = []
    combined_ids = []

    for data in get_results:
        combined_documents.extend(data["documents"][0])
        combined_metadatas.extend(data["metadatas"][0])
        combined_ids.extend(data["ids"][0])

    # Create the output dictionary
    result = {
        "documents": [combined_documents],
        "metadatas": [combined_metadatas],
        "ids": [combined_ids],
    }

    return result
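
# Merges ranked results from multiple collections/queries, deduplicating by a
# SHA-256 hash of the document text and keeping the best score per document.
# Scores are treated as similarities (higher is better). Illustrative example:
# two results carrying the same text with scores 0.4 and 0.7 collapse to one
# entry with score 0.7; the merged list is then sorted descending and cut to k.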
def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict:
    combined = dict()  # Stores documents keyed by a unique content hash

    for data in query_results:
        distances = data["distances"][0]
        documents = data["documents"][0]
        metadatas = data["metadatas"][0]

        for distance, document, metadata in zip(distances, documents, metadatas):
            if isinstance(document, str):
                doc_hash = hashlib.sha256(
                    document.encode()
                ).hexdigest()  # Compute a hash for uniqueness

                if doc_hash not in combined:
                    combined[doc_hash] = (distance, document, metadata)
                    continue  # if the document is new, no further comparison is needed

                # if the document is already present but the new distance is better, update
                if distance > combined[doc_hash][0]:
                    combined[doc_hash] = (distance, document, metadata)

    combined = list(combined.values())

    # Sort the list based on distances
    combined.sort(key=lambda x: x[0], reverse=True)

    # Slice to keep only the top k elements
    sorted_distances, sorted_documents, sorted_metadatas = (
        zip(*combined[:k]) if combined else ([], [], [])
    )

    # Create and return the output dictionary
    return {
        "distances": [list(sorted_distances)],
        "documents": [list(sorted_documents)],
        "metadatas": [list(sorted_metadatas)],
    }
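
# Full-context helper: fetches every stored item from each named collection
# (no ranking) and merges them into one GetResult-shaped dict.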
def get_all_items_from_collections(collection_names: list[str]) -> dict:
    results = []

    for collection_name in collection_names:
        if collection_name:
            try:
                result = get_doc(collection_name=collection_name)
                if result is not None:
                    results.append(result.model_dump())
            except Exception as e:
                log.exception(f"Error when querying the collection: {e}")

    return merge_get_results(results)
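
# Fans dense search out over every (query, collection) pair. All query
# embeddings are computed in a single batched call first, then the per-pair
# lookups run concurrently on a thread pool; failed lookups are logged and
# skipped rather than aborting the whole batch.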
def query_collection(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
) -> dict:
    results = []
    error = False

    def process_query_collection(collection_name, query_embedding):
        try:
            if collection_name:
                result = query_doc(
                    collection_name=collection_name,
                    k=k,
                    query_embedding=query_embedding,
                )
                if result is not None:
                    return result.model_dump(), None
            return None, None
        except Exception as e:
            log.exception(f"Error when querying the collection: {e}")
            return None, e

    # Generate all query embeddings (in one call)
    query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
    log.debug(
        f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
    )

    with ThreadPoolExecutor() as executor:
        future_results = []
        for query_embedding in query_embeddings:
            for collection_name in collection_names:
                result = executor.submit(
                    process_query_collection, collection_name, query_embedding
                )
                future_results.append(result)
        task_results = [future.result() for future in future_results]

    for result, err in task_results:
        if err is not None:
            error = True
        elif result is not None:
            results.append(result)

    if error and not results:
        log.warning("All collection queries failed. No results returned.")

    return merge_and_sort_query_results(results, k=k)
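
# Hybrid-search counterpart of query_collection. Each collection's contents
# are fetched once up front (BM25 needs the raw texts), then one hybrid search
# per (collection, query) pair runs on a thread pool. If every task fails, an
# exception is raised so the caller can fall back to dense-only search.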
def query_collection_with_hybrid_search(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
) -> dict:
    results = []
    error = False

    # Fetch collection data once per collection, sequentially, to avoid
    # fetching the same data multiple times later
    collection_results = {}
    for collection_name in collection_names:
        try:
            log.debug(
                f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}"
            )
            collection_results[collection_name] = VECTOR_DB_CLIENT.get(
                collection_name=collection_name
            )
        except Exception as e:
            log.exception(f"Failed to fetch collection {collection_name}: {e}")
            collection_results[collection_name] = None

    log.info(
        f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
    )

    def process_query(collection_name, query):
        try:
            result = query_doc_with_hybrid_search(
                collection_name=collection_name,
                collection_result=collection_results[collection_name],
                query=query,
                embedding_function=embedding_function,
                k=k,
                reranking_function=reranking_function,
                k_reranker=k_reranker,
                r=r,
            )
            return result, None
        except Exception as e:
            log.exception(f"Error when querying the collection with hybrid_search: {e}")
            return None, e

    # Prepare tasks for all collections and queries, skipping any collection
    # whose data could not be fetched (assigned None above)
    tasks = [
        (cn, q)
        for cn in collection_names
        if collection_results[cn] is not None
        for q in queries
    ]

    with ThreadPoolExecutor() as executor:
        future_results = [executor.submit(process_query, cn, q) for cn, q in tasks]
        task_results = [future.result() for future in future_results]

    for result, err in task_results:
        if err is not None:
            error = True
        elif result is not None:
            results.append(result)

    if error and not results:
        raise Exception(
            "Hybrid search failed for all collections; the caller should fall "
            "back to non-hybrid search."
        )

    return merge_and_sort_query_results(results, k=k)
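
# Returns a callable embed(query, prefix=None, user=None) that accepts a single
# string or a list of strings. For "" (local sentence-transformers) it calls
# .encode() directly; for "ollama"/"openai" it batches lists in chunks of
# embedding_batch_size through generate_embeddings. Hypothetical usage sketch
# (api_url/api_key and the texts are illustrative, not defined here):
#
#     embed = get_embedding_function(
#         "openai", "text-embedding-3-small", None, api_url, api_key, 32
#     )
#     vectors = embed(["first chunk", "second chunk"],
#                     prefix=RAG_EMBEDDING_CONTENT_PREFIX)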
def get_embedding_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    url,
    key,
    embedding_batch_size,
):
    if embedding_engine == "":
        return lambda query, prefix=None, user=None: embedding_function.encode(
            query, **({"prompt": prefix} if prefix else {})
        ).tolist()
    elif embedding_engine in ["ollama", "openai"]:
        func = lambda query, prefix=None, user=None: generate_embeddings(
            engine=embedding_engine,
            model=embedding_model,
            text=query,
            prefix=prefix,
            url=url,
            key=key,
            user=user,
        )

        def generate_multiple(query, prefix, user, func):
            if isinstance(query, list):
                embeddings = []
                for i in range(0, len(query), embedding_batch_size):
                    embeddings.extend(
                        func(
                            query[i : i + embedding_batch_size],
                            prefix=prefix,
                            user=user,
                        )
                    )
                return embeddings
            else:
                return func(query, prefix, user)

        return lambda query, prefix=None, user=None: generate_multiple(
            query, prefix, user, func
        )
    else:
        raise ValueError(f"Unknown embedding engine: {embedding_engine}")
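
# Resolves each attached file/collection to a "context" and flattens the
# results into source dicts for citation. Contexts are chosen in priority
# order: inline docs, an explicit full-context toggle, bypass-retrieval file
# content, and finally vector/hybrid retrieval over the relevant collections
# (hybrid search falls back to dense-only search on error). Collections are
# deduplicated across files via extracted_collections.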
def get_sources_from_files(
    request,
    files,
    queries,
    embedding_function,
    k,
    reranking_function,
    k_reranker,
    r,
    hybrid_search,
    full_context=False,
):
    log.debug(
        f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
    )

    extracted_collections = []
    relevant_contexts = []

    for file in files:
        context = None
        if file.get("docs"):
            # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
            context = {
                "documents": [[doc.get("content") for doc in file.get("docs")]],
                "metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
            }
        elif file.get("context") == "full":
            # Manual full-context mode toggle
            context = {
                "documents": [[file.get("file").get("data", {}).get("content")]],
                "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
            }
        elif (
            file.get("type") != "web_search"
            and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
        ):
            # BYPASS_EMBEDDING_AND_RETRIEVAL
            if file.get("type") == "collection":
                file_ids = file.get("data", {}).get("file_ids", [])

                documents = []
                metadatas = []
                for file_id in file_ids:
                    file_object = Files.get_file_by_id(file_id)

                    if file_object:
                        documents.append(file_object.data.get("content", ""))
                        metadatas.append(
                            {
                                "file_id": file_id,
                                "name": file_object.filename,
                                "source": file_object.filename,
                            }
                        )

                context = {
                    "documents": [documents],
                    "metadatas": [metadatas],
                }
            elif file.get("id"):
                file_object = Files.get_file_by_id(file.get("id"))
                if file_object:
                    context = {
                        "documents": [[file_object.data.get("content", "")]],
                        "metadatas": [
                            [
                                {
                                    "file_id": file.get("id"),
                                    "name": file_object.filename,
                                    "source": file_object.filename,
                                }
                            ]
                        ],
                    }
            elif file.get("file").get("data"):
                context = {
                    "documents": [[file.get("file").get("data", {}).get("content")]],
                    "metadatas": [
                        [file.get("file").get("data", {}).get("metadata", {})]
                    ],
                }
        else:
            collection_names = []
            if file.get("type") == "collection":
                if file.get("legacy"):
                    collection_names = file.get("collection_names", [])
                else:
                    collection_names.append(file["id"])
            elif file.get("collection_name"):
                collection_names.append(file["collection_name"])
            elif file.get("id"):
                if file.get("legacy"):
                    collection_names.append(f"{file['id']}")
                else:
                    collection_names.append(f"file-{file['id']}")

            collection_names = set(collection_names).difference(extracted_collections)
            if not collection_names:
                log.debug(f"skipping {file} as it has already been extracted")
                continue

            if full_context:
                try:
                    context = get_all_items_from_collections(collection_names)
                except Exception as e:
                    log.exception(e)
            else:
                try:
                    context = None
                    if file.get("type") == "text":
                        context = file["content"]
                    else:
                        if hybrid_search:
                            try:
                                context = query_collection_with_hybrid_search(
                                    collection_names=collection_names,
                                    queries=queries,
                                    embedding_function=embedding_function,
                                    k=k,
                                    reranking_function=reranking_function,
                                    k_reranker=k_reranker,
                                    r=r,
                                )
                            except Exception as e:
                                log.debug(
                                    f"Error when using hybrid search: {e}. "
                                    "Using non-hybrid search as fallback."
                                )

                        if (not hybrid_search) or (context is None):
                            context = query_collection(
                                collection_names=collection_names,
                                queries=queries,
                                embedding_function=embedding_function,
                                k=k,
                            )
                except Exception as e:
                    log.exception(e)

            extracted_collections.extend(collection_names)

        if context:
            if "data" in file:
                del file["data"]

            relevant_contexts.append({**context, "file": file})

    sources = []
    for context in relevant_contexts:
        try:
            if "documents" in context:
                if "metadatas" in context:
                    source = {
                        "source": context["file"],
                        "document": context["documents"][0],
                        "metadata": context["metadatas"][0],
                    }
                    if "distances" in context and context["distances"]:
                        source["distances"] = context["distances"][0]

                    sources.append(source)
        except Exception as e:
            log.exception(e)

    return sources
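
# Resolves a sentence-transformers model name to a local snapshot path via
# huggingface_hub. local_files_only is forced when OFFLINE_MODE is set;
# otherwise update_model=True allows a download/refresh. Bare model names are
# qualified under the "sentence-transformers/" namespace.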
def get_model_path(model: str, update_model: bool = False):
    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")

    local_files_only = not update_model
    if OFFLINE_MODE:
        local_files_only = True

    snapshot_kwargs = {
        "cache_dir": cache_dir,
        "local_files_only": local_files_only,
    }

    log.debug(f"model: {model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # Inspiration from upstream sentence_transformers
    if os.path.exists(model) or (
        ("\\" in model or model.count("/") > 1) and local_files_only
    ):
        # If a fully qualified path exists, return the input; otherwise set repo_id
        return model
    elif "/" not in model:
        # Qualify a bare model short-name under the sentence-transformers namespace
        model = "sentence-transformers/" + model

    snapshot_kwargs["repo_id"] = model

    # Attempt to query the huggingface_hub library to determine the local path and/or to update
    try:
        model_repo_path = snapshot_download(**snapshot_kwargs)
        log.debug(f"model_repo_path: {model_repo_path}")
        return model_repo_path
    except Exception as e:
        log.exception(f"Cannot determine model snapshot path: {e}")
        return model
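
# Calls the OpenAI-compatible POST {url}/embeddings endpoint for a batch of
# texts and returns one vector per input, or None on failure. The request body
# is {"input": [...], "model": "..."}, optionally extended with a prefix field
# named by RAG_EMBEDDING_PREFIX_FIELD_NAME; user-identifying headers are
# forwarded when ENABLE_FORWARD_USER_INFO_HEADERS is set.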
def generate_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str = "https://api.openai.com/v1",
    key: str = "",
    prefix: Optional[str] = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    try:
        log.debug(
            f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
        )
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        r = requests.post(
            f"{url}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
                **(
                    {
                        "X-OpenWebUI-User-Name": user.name,
                        "X-OpenWebUI-User-Id": user.id,
                        "X-OpenWebUI-User-Email": user.email,
                        "X-OpenWebUI-User-Role": user.role,
                    }
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
                    else {}
                ),
            },
            json=json_data,
        )
        r.raise_for_status()
        data = r.json()
        if "data" in data:
            return [elem["embedding"] for elem in data["data"]]
        else:
            raise ValueError("Something went wrong :/")
    except Exception as e:
        log.exception(f"Error generating openai batch embeddings: {e}")
        return None

def generate_ollama_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = "",
    prefix: Optional[str] = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    try:
        log.debug(
            f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
        )
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        r = requests.post(
            f"{url}/api/embed",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
                **(
                    {
                        "X-OpenWebUI-User-Name": user.name,
                        "X-OpenWebUI-User-Id": user.id,
                        "X-OpenWebUI-User-Email": user.email,
                        "X-OpenWebUI-User-Role": user.role,
                    }
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
                    else {}
                ),
            },
            json=json_data,
        )
        r.raise_for_status()
        data = r.json()
        if "embeddings" in data:
            return data["embeddings"]
        else:
            raise ValueError("Something went wrong :/")
    except Exception as e:
        log.exception(f"Error generating ollama batch embeddings: {e}")
        return None
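
# Engine dispatcher used by get_embedding_function: routes to the Ollama or
# OpenAI batch helpers, normalizing a single string to a one-element batch and
# unwrapping the single vector on return. When no prefix field name is
# configured, the prefix is prepended to the text itself instead.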
def generate_embeddings(
    engine: str,
    model: str,
    text: Union[str, list[str]],
    prefix: Union[str, None] = None,
    **kwargs,
):
    url = kwargs.get("url", "")
    key = kwargs.get("key", "")
    user = kwargs.get("user")

    if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
        if isinstance(text, list):
            text = [f"{prefix}{text_element}" for text_element in text]
        else:
            text = f"{prefix}{text}"

    if engine == "ollama":
        embeddings = generate_ollama_batch_embeddings(
            model=model,
            texts=text if isinstance(text, list) else [text],
            url=url,
            key=key,
            prefix=prefix,
            user=user,
        )
        return embeddings[0] if isinstance(text, str) else embeddings
    elif engine == "openai":
        embeddings = generate_openai_batch_embeddings(
            model, text if isinstance(text, list) else [text], url, key, prefix, user
        )
        return embeddings[0] if isinstance(text, str) else embeddings
    else:
        raise ValueError(f"Unknown embedding engine: {engine}")

import operator
from typing import Optional, Sequence

from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
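
# Document compressor used as the reranking stage of hybrid search. With a
# reranking_function (a cross-encoder-style object exposing .predict on
# (query, passage) pairs) it uses those scores directly; otherwise it falls
# back to cosine similarity between query and document embeddings. Documents
# scoring below r_score are dropped, and the top_n survivors are returned with
# their score recorded in metadata["score"].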
class RerankCompressor(BaseDocumentCompressor):
    embedding_function: Any
    top_n: int
    reranking_function: Any
    r_score: float

    class Config:
        extra = "forbid"
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        reranking = self.reranking_function is not None
        if reranking:
            scores = self.reranking_function.predict(
                [(query, doc.page_content) for doc in documents]
            )
        else:
            from sentence_transformers import util

            query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
            document_embedding = self.embedding_function(
                [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
            )
            scores = util.cos_sim(query_embedding, document_embedding)[0]

        docs_with_scores = list(
            zip(documents, scores.tolist() if not isinstance(scores, list) else scores)
        )
        if self.r_score:
            docs_with_scores = [
                (d, s) for d, s in docs_with_scores if s >= self.r_score
            ]

        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
        final_results = []
        for doc, doc_score in result[: self.top_n]:
            metadata = doc.metadata
            metadata["score"] = doc_score
            doc = Document(
                page_content=doc.page_content,
                metadata=metadata,
            )
            final_results.append(doc)
        return final_results