import logging
import os
from typing import Optional, Union

import requests
import hashlib
from concurrent.futures import ThreadPoolExecutor
import time
from urllib.parse import quote

from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document

from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.models.users import UserModel
from open_webui.models.files import Files
from open_webui.models.notes import Notes
from open_webui.retrieval.vector.main import GetResult

from open_webui.env import (
    SRC_LOG_LEVELS,
    OFFLINE_MODE,
    ENABLE_FORWARD_USER_INFO_HEADERS,
)
from open_webui.config import (
    RAG_EMBEDDING_QUERY_PREFIX,
    RAG_EMBEDDING_CONTENT_PREFIX,
    RAG_EMBEDDING_PREFIX_FIELD_NAME,
)

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

from typing import Any

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever


class VectorSearchRetriever(BaseRetriever):
    collection_name: Any
    embedding_function: Any
    top_k: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        result = VECTOR_DB_CLIENT.search(
            collection_name=self.collection_name,
            vectors=[self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)],
            limit=self.top_k,
        )

        ids = result.ids[0]
        metadatas = result.metadatas[0]
        documents = result.documents[0]

        results = []
        for idx in range(len(ids)):
            results.append(
                Document(
                    metadata=metadatas[idx],
                    page_content=documents[idx],
                )
            )
        return results
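

# Illustrative usage sketch (not part of this module; the collection name and
# embedding function below are placeholder assumptions):
#
#     retriever = VectorSearchRetriever(
#         collection_name="file-<id>",
#         embedding_function=embedding_function,  # e.g. from get_embedding_function()
#         top_k=5,
#     )
#     docs = retriever.invoke("example query")  # -> list[Document]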


def query_doc(
    collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
):
    try:
        log.debug(f"query_doc:doc {collection_name}")
        result = VECTOR_DB_CLIENT.search(
            collection_name=collection_name,
            vectors=[query_embedding],
            limit=k,
        )

        if result:
            log.info(f"query_doc:result {result.ids} {result.metadatas}")

        return result
    except Exception as e:
        log.exception(f"Error querying doc {collection_name} with limit {k}: {e}")
        raise e


def get_doc(collection_name: str, user: UserModel = None):
    try:
        log.debug(f"get_doc:doc {collection_name}")
        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)

        if result:
            log.info(f"get_doc:result {result.ids} {result.metadatas}")

        return result
    except Exception as e:
        log.exception(f"Error getting doc {collection_name}: {e}")
        raise e


def query_doc_with_hybrid_search(
    collection_name: str,
    collection_result: GetResult,
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
    hybrid_bm25_weight: float,
) -> dict:
    try:
        log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
        bm25_retriever = BM25Retriever.from_texts(
            texts=collection_result.documents[0],
            metadatas=collection_result.metadatas[0],
        )
        bm25_retriever.k = k

        vector_search_retriever = VectorSearchRetriever(
            collection_name=collection_name,
            embedding_function=embedding_function,
            top_k=k,
        )

        if hybrid_bm25_weight <= 0:
            ensemble_retriever = EnsembleRetriever(
                retrievers=[vector_search_retriever], weights=[1.0]
            )
        elif hybrid_bm25_weight >= 1:
            ensemble_retriever = EnsembleRetriever(
                retrievers=[bm25_retriever], weights=[1.0]
            )
        else:
            ensemble_retriever = EnsembleRetriever(
                retrievers=[bm25_retriever, vector_search_retriever],
                weights=[hybrid_bm25_weight, 1.0 - hybrid_bm25_weight],
            )

        compressor = RerankCompressor(
            embedding_function=embedding_function,
            top_n=k_reranker,
            reranking_function=reranking_function,
            r_score=r,
        )

        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=ensemble_retriever
        )

        result = compression_retriever.invoke(query)

        distances = [d.metadata.get("score") for d in result]
        documents = [d.page_content for d in result]
        metadatas = [d.metadata for d in result]

        # retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker
        if k < k_reranker:
            sorted_items = sorted(
                zip(distances, metadatas, documents), key=lambda x: x[0], reverse=True
            )
            sorted_items = sorted_items[:k]
            # Unpack in the same order the tuples were zipped above:
            # (distance, metadata, document).
            distances, metadatas, documents = map(list, zip(*sorted_items))

        result = {
            "distances": [distances],
            "documents": [documents],
            "metadatas": [metadatas],
        }

        log.info(
            "query_doc_with_hybrid_search:result "
            + f'{result["metadatas"]} {result["distances"]}'
        )
        return result
    except Exception as e:
        log.exception(f"Error querying doc {collection_name} with hybrid search: {e}")
        raise e
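

# Hedged usage sketch for the hybrid pipeline above (the collection name and
# reranker are placeholder assumptions; any object with a .predict() over
# (query, text) pairs, such as a sentence-transformers CrossEncoder, fits):
#
#     collection = VECTOR_DB_CLIENT.get(collection_name="file-<id>")
#     result = query_doc_with_hybrid_search(
#         collection_name="file-<id>",
#         collection_result=collection,
#         query="example query",
#         embedding_function=embedding_function,
#         k=5,
#         reranking_function=reranker,
#         k_reranker=10,
#         r=0.0,
#         hybrid_bm25_weight=0.5,  # 0 -> vector only, 1 -> BM25 only
#     )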


def merge_get_results(get_results: list[dict]) -> dict:
    # Initialize lists to store combined data
    combined_documents = []
    combined_metadatas = []
    combined_ids = []

    for data in get_results:
        combined_documents.extend(data["documents"][0])
        combined_metadatas.extend(data["metadatas"][0])
        combined_ids.extend(data["ids"][0])

    # Create the output dictionary
    result = {
        "documents": [combined_documents],
        "metadatas": [combined_metadatas],
        "ids": [combined_ids],
    }

    return result


def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict:
    # Initialize lists to store combined data
    combined = dict()  # To store documents with unique document hashes

    for data in query_results:
        distances = data["distances"][0]
        documents = data["documents"][0]
        metadatas = data["metadatas"][0]

        for distance, document, metadata in zip(distances, documents, metadatas):
            if isinstance(document, str):
                doc_hash = hashlib.sha256(
                    document.encode()
                ).hexdigest()  # Compute a hash for uniqueness

                if doc_hash not in combined:
                    combined[doc_hash] = (distance, document, metadata)
                    continue  # if doc is new, no further comparison is needed

                # if doc is already in, but new distance is better, update
                if distance > combined[doc_hash][0]:
                    combined[doc_hash] = (distance, document, metadata)

    combined = list(combined.values())
    # Sort the list based on distances
    combined.sort(key=lambda x: x[0], reverse=True)

    # Slice to keep only the top k elements
    sorted_distances, sorted_documents, sorted_metadatas = (
        zip(*combined[:k]) if combined else ([], [], [])
    )

    # Create and return the output dictionary
    return {
        "distances": [list(sorted_distances)],
        "documents": [list(sorted_documents)],
        "metadatas": [list(sorted_metadatas)],
    }
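

# Worked example (hypothetical values): duplicates are collapsed by the
# SHA-256 of the document text, the higher score wins, and the merged list is
# sorted descending and cut to k:
#
#     a = {"distances": [[0.9, 0.4]], "documents": [["x", "y"]], "metadatas": [[{}, {}]]}
#     b = {"distances": [[0.7]], "documents": [["x"]], "metadatas": [[{}]]}
#     merge_and_sort_query_results([a, b], k=2)
#     # -> {"distances": [[0.9, 0.4]], "documents": [["x", "y"]], "metadatas": [[{}, {}]]}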


def get_all_items_from_collections(collection_names: list[str]) -> dict:
    results = []

    for collection_name in collection_names:
        if collection_name:
            try:
                result = get_doc(collection_name=collection_name)
                if result is not None:
                    results.append(result.model_dump())
            except Exception as e:
                log.exception(f"Error when querying the collection: {e}")

    return merge_get_results(results)


def query_collection(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
) -> dict:
    results = []
    error = False

    def process_query_collection(collection_name, query_embedding):
        try:
            if collection_name:
                result = query_doc(
                    collection_name=collection_name,
                    k=k,
                    query_embedding=query_embedding,
                )
                if result is not None:
                    return result.model_dump(), None
            return None, None
        except Exception as e:
            log.exception(f"Error when querying the collection: {e}")
            return None, e

    # Generate all query embeddings (in one call)
    query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
    log.debug(
        f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
    )

    with ThreadPoolExecutor() as executor:
        future_results = []
        for query_embedding in query_embeddings:
            for collection_name in collection_names:
                result = executor.submit(
                    process_query_collection, collection_name, query_embedding
                )
                future_results.append(result)
        task_results = [future.result() for future in future_results]

    for result, err in task_results:
        if err is not None:
            error = True
        elif result is not None:
            results.append(result)

    if error and not results:
        log.warning("All collection queries failed. No results returned.")

    return merge_and_sort_query_results(results, k=k)


def query_collection_with_hybrid_search(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
    hybrid_bm25_weight: float,
) -> dict:
    results = []
    error = False

    # Fetch collection data once per collection sequentially
    # Avoid fetching the same data multiple times later
    collection_results = {}
    for collection_name in collection_names:
        try:
            log.debug(
                f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}"
            )
            collection_results[collection_name] = VECTOR_DB_CLIENT.get(
                collection_name=collection_name
            )
        except Exception as e:
            log.exception(f"Failed to fetch collection {collection_name}: {e}")
            collection_results[collection_name] = None

    log.info(
        f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
    )

    def process_query(collection_name, query):
        try:
            result = query_doc_with_hybrid_search(
                collection_name=collection_name,
                collection_result=collection_results[collection_name],
                query=query,
                embedding_function=embedding_function,
                k=k,
                reranking_function=reranking_function,
                k_reranker=k_reranker,
                r=r,
                hybrid_bm25_weight=hybrid_bm25_weight,
            )
            return result, None
        except Exception as e:
            log.exception(f"Error when querying the collection with hybrid_search: {e}")
            return None, e

    # Prepare tasks for all collections and queries
    # Avoid running any tasks for collections that failed to fetch data (have assigned None)
    tasks = [
        (cn, q)
        for cn in collection_names
        if collection_results[cn] is not None
        for q in queries
    ]

    with ThreadPoolExecutor() as executor:
        future_results = [executor.submit(process_query, cn, q) for cn, q in tasks]
        task_results = [future.result() for future in future_results]

    for result, err in task_results:
        if err is not None:
            error = True
        elif result is not None:
            results.append(result)

    if error and not results:
        raise Exception(
            "Hybrid search failed for all collections. Using Non-hybrid search as fallback."
        )

    return merge_and_sort_query_results(results, k=k)


def get_embedding_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    url,
    key,
    embedding_batch_size,
    azure_api_version=None,
):
    if embedding_engine == "":
        return lambda query, prefix=None, user=None: embedding_function.encode(
            query, **({"prompt": prefix} if prefix else {})
        ).tolist()
    elif embedding_engine in ["ollama", "openai", "azure_openai"]:
        func = lambda query, prefix=None, user=None: generate_embeddings(
            engine=embedding_engine,
            model=embedding_model,
            text=query,
            prefix=prefix,
            url=url,
            key=key,
            user=user,
            azure_api_version=azure_api_version,
        )

        def generate_multiple(query, prefix, user, func):
            if isinstance(query, list):
                embeddings = []
                for i in range(0, len(query), embedding_batch_size):
                    embeddings.extend(
                        func(
                            query[i : i + embedding_batch_size],
                            prefix=prefix,
                            user=user,
                        )
                    )
                return embeddings
            else:
                return func(query, prefix, user)

        return lambda query, prefix=None, user=None: generate_multiple(
            query, prefix, user, func
        )
    else:
        raise ValueError(f"Unknown embedding engine: {embedding_engine}")
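

# Illustrative sketch (engine, model, url, and key are placeholder
# assumptions). The returned callable accepts a single string or a list of
# strings; lists are embedded in batches of embedding_batch_size:
#
#     embedding_function = get_embedding_function(
#         embedding_engine="openai",
#         embedding_model="text-embedding-3-small",
#         embedding_function=None,  # only used when embedding_engine == ""
#         url="https://api.openai.com/v1",
#         key="sk-...",
#         embedding_batch_size=32,
#     )
#     vector = embedding_function("hello", prefix=RAG_EMBEDDING_QUERY_PREFIX)
#     vectors = embedding_function(["a", "b"], prefix=RAG_EMBEDDING_CONTENT_PREFIX)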


def get_sources_from_files(
    request,
    files,
    queries,
    embedding_function,
    k,
    reranking_function,
    k_reranker,
    r,
    hybrid_bm25_weight,
    hybrid_search,
    full_context=False,
):
    log.debug(
        f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
    )

    extracted_collections = []
    query_results = []

    for file in files:
        query_result = None

        if file.get("docs"):
            # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
            query_result = {
                "documents": [[doc.get("content") for doc in file.get("docs")]],
                "metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
            }
        elif file.get("type") == "text":
            # Text File
            query_result = {
                "documents": [[file.get("content")]],
                "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
            }
        elif file.get("type") == "note":
            # Note Attached
            note = Notes.get_note_by_id(file.get("id"))

            query_result = {
                "documents": [[note.data.get("content", {}).get("md", "")]],
                "metadatas": [[{"file_id": note.id, "name": note.title}]],
            }
        elif file.get("context") == "full":
            if file.get("type") == "file":
                # Manual Full Mode Toggle
                query_result = {
                    "documents": [[file.get("file").get("data", {}).get("content")]],
                    "metadatas": [
                        [{"file_id": file.get("id"), "name": file.get("name")}]
                    ],
                }
            elif file.get("type") == "collection":
                # Manual Full Mode Toggle for Collection
                file_ids = file.get("data", {}).get("file_ids", [])

                documents = []
                metadatas = []
                for file_id in file_ids:
                    file_object = Files.get_file_by_id(file_id)

                    if file_object:
                        documents.append(file_object.data.get("content", ""))
                        metadatas.append(
                            {
                                "file_id": file_id,
                                "name": file_object.filename,
                                "source": file_object.filename,
                            }
                        )

                query_result = {
                    "documents": [documents],
                    "metadatas": [metadatas],
                }
        elif (
            file.get("type") != "web_search"
            and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
        ):
            # BYPASS_EMBEDDING_AND_RETRIEVAL
            if file.get("type") == "collection":
                file_ids = file.get("data", {}).get("file_ids", [])

                documents = []
                metadatas = []
                for file_id in file_ids:
                    file_object = Files.get_file_by_id(file_id)

                    if file_object:
                        documents.append(file_object.data.get("content", ""))
                        metadatas.append(
                            {
                                "file_id": file_id,
                                "name": file_object.filename,
                                "source": file_object.filename,
                            }
                        )

                query_result = {
                    "documents": [documents],
                    "metadatas": [metadatas],
                }
            elif file.get("id"):
                file_object = Files.get_file_by_id(file.get("id"))
                if file_object:
                    query_result = {
                        "documents": [[file_object.data.get("content", "")]],
                        "metadatas": [
                            [
                                {
                                    "file_id": file.get("id"),
                                    "name": file_object.filename,
                                    "source": file_object.filename,
                                }
                            ]
                        ],
                    }
            elif file.get("file").get("data"):
                query_result = {
                    "documents": [[file.get("file").get("data", {}).get("content")]],
                    "metadatas": [
                        [file.get("file").get("data", {}).get("metadata", {})]
                    ],
                }
        else:
            collection_names = []
            if file.get("type") == "collection":
                if file.get("legacy"):
                    collection_names = file.get("collection_names", [])
                else:
                    collection_names.append(file["id"])
            elif file.get("collection_name"):
                collection_names.append(file["collection_name"])
            elif file.get("id"):
                if file.get("legacy"):
                    collection_names.append(f"{file['id']}")
                else:
                    collection_names.append(f"file-{file['id']}")

            collection_names = set(collection_names).difference(extracted_collections)
            if not collection_names:
                log.debug(f"skipping {file} as it has already been extracted")
                continue

            if full_context:
                try:
                    query_result = get_all_items_from_collections(collection_names)
                except Exception as e:
                    log.exception(e)
            else:
                try:
                    query_result = None
                    if file.get("type") == "text":
                        # Not sure when this is used, but it seems to be a fallback
                        query_result = {
                            "documents": [
                                [file.get("file").get("data", {}).get("content")]
                            ],
                            "metadatas": [
                                [file.get("file").get("data", {}).get("meta", {})]
                            ],
                        }
                    else:
                        if hybrid_search:
                            try:
                                query_result = query_collection_with_hybrid_search(
                                    collection_names=collection_names,
                                    queries=queries,
                                    embedding_function=embedding_function,
                                    k=k,
                                    reranking_function=reranking_function,
                                    k_reranker=k_reranker,
                                    r=r,
                                    hybrid_bm25_weight=hybrid_bm25_weight,
                                )
                            except Exception as e:
                                log.debug(
                                    "Error when using hybrid search, using"
                                    " non-hybrid search as fallback."
                                )

                        if (not hybrid_search) or (query_result is None):
                            query_result = query_collection(
                                collection_names=collection_names,
                                queries=queries,
                                embedding_function=embedding_function,
                                k=k,
                            )
                except Exception as e:
                    log.exception(e)

            extracted_collections.extend(collection_names)

        if query_result:
            if "data" in file:
                del file["data"]

            query_results.append({**query_result, "file": file})

    sources = []
    for query_result in query_results:
        try:
            if "documents" in query_result:
                if "metadatas" in query_result:
                    source = {
                        "source": query_result["file"],
                        "document": query_result["documents"][0],
                        "metadata": query_result["metadatas"][0],
                    }
                    if "distances" in query_result and query_result["distances"]:
                        source["distances"] = query_result["distances"][0]

                    sources.append(source)
        except Exception as e:
            log.exception(e)

    return sources


def get_model_path(model: str, update_model: bool = False):
    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")

    local_files_only = not update_model
    if OFFLINE_MODE:
        local_files_only = True

    snapshot_kwargs = {
        "cache_dir": cache_dir,
        "local_files_only": local_files_only,
    }

    log.debug(f"model: {model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # Inspiration from upstream sentence_transformers
    if (
        os.path.exists(model)
        or ("\\" in model or model.count("/") > 1)
        and local_files_only
    ):
        # If fully qualified path exists, return input, else set repo_id
        return model
    elif "/" not in model:
        # Set valid repo_id for model short-name
        model = "sentence-transformers" + "/" + model

    snapshot_kwargs["repo_id"] = model

    # Attempt to query the huggingface_hub library to determine the local path and/or to update
    try:
        model_repo_path = snapshot_download(**snapshot_kwargs)
        log.debug(f"model_repo_path: {model_repo_path}")
        return model_repo_path
    except Exception as e:
        log.exception(f"Cannot determine model snapshot path: {e}")
        return model
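

# Illustrative sketch (model names are assumptions): a bare short name is
# resolved under the "sentence-transformers" namespace and located or
# downloaded via huggingface_hub; an existing filesystem path is returned
# unchanged:
#
#     path = get_model_path("all-MiniLM-L6-v2")   # resolves repo, may download
#     path = get_model_path("/models/my-model")   # returned as-is if it exists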


def generate_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str = "https://api.openai.com/v1",
    key: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    try:
        log.debug(
            f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
        )
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        r = requests.post(
            f"{url}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
                **(
                    {
                        "X-OpenWebUI-User-Name": quote(user.name),
                        "X-OpenWebUI-User-Id": quote(user.id),
                        "X-OpenWebUI-User-Email": quote(user.email),
                        "X-OpenWebUI-User-Role": quote(user.role),
                    }
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
                    else {}
                ),
            },
            json=json_data,
        )
        r.raise_for_status()
        data = r.json()

        if "data" in data:
            return [elem["embedding"] for elem in data["data"]]
        else:
            # `raise` needs an exception instance, not a bare string
            raise Exception("Something went wrong :/")
    except Exception as e:
        log.exception(f"Error generating openai batch embeddings: {e}")
        return None


def generate_azure_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = "",
    version: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    try:
        log.debug(
            f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
        )
        json_data = {"input": texts}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"

        for _ in range(5):
            r = requests.post(
                url,
                headers={
                    "Content-Type": "application/json",
                    "api-key": key,
                    **(
                        {
                            "X-OpenWebUI-User-Name": quote(user.name),
                            "X-OpenWebUI-User-Id": quote(user.id),
                            "X-OpenWebUI-User-Email": quote(user.email),
                            "X-OpenWebUI-User-Role": quote(user.role),
                        }
                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
                        else {}
                    ),
                },
                json=json_data,
            )
            if r.status_code == 429:
                retry = float(r.headers.get("Retry-After", "1"))
                time.sleep(retry)
                continue
            r.raise_for_status()
            data = r.json()
            if "data" in data:
                return [elem["embedding"] for elem in data["data"]]
            else:
                raise Exception("Something went wrong :/")
        return None
    except Exception as e:
        log.exception(f"Error generating azure openai batch embeddings: {e}")
        return None


def generate_ollama_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    try:
        log.debug(
            f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
        )
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        r = requests.post(
            f"{url}/api/embed",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
                **(
                    {
                        "X-OpenWebUI-User-Name": quote(user.name),
                        "X-OpenWebUI-User-Id": quote(user.id),
                        "X-OpenWebUI-User-Email": quote(user.email),
                        "X-OpenWebUI-User-Role": quote(user.role),
                    }
                    # Guard on `user` as well, matching the other embedding
                    # backends; quoting attributes of None would raise.
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
                    else {}
                ),
            },
            json=json_data,
        )
        r.raise_for_status()
        data = r.json()

        if "embeddings" in data:
            return data["embeddings"]
        else:
            # `raise` needs an exception instance, not a bare string
            raise Exception("Something went wrong :/")
    except Exception as e:
        log.exception(f"Error generating ollama batch embeddings: {e}")
        return None


def generate_embeddings(
    engine: str,
    model: str,
    text: Union[str, list[str]],
    prefix: Union[str, None] = None,
    **kwargs,
):
    url = kwargs.get("url", "")
    key = kwargs.get("key", "")
    user = kwargs.get("user")

    if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
        if isinstance(text, list):
            text = [f"{prefix}{text_element}" for text_element in text]
        else:
            text = f"{prefix}{text}"

    if engine == "ollama":
        embeddings = generate_ollama_batch_embeddings(
            **{
                "model": model,
                "texts": text if isinstance(text, list) else [text],
                "url": url,
                "key": key,
                "prefix": prefix,
                "user": user,
            }
        )
        return embeddings[0] if isinstance(text, str) else embeddings
    elif engine == "openai":
        embeddings = generate_openai_batch_embeddings(
            model, text if isinstance(text, list) else [text], url, key, prefix, user
        )
        return embeddings[0] if isinstance(text, str) else embeddings
    elif engine == "azure_openai":
        azure_api_version = kwargs.get("azure_api_version", "")
        embeddings = generate_azure_openai_batch_embeddings(
            model,
            text if isinstance(text, list) else [text],
            url,
            key,
            azure_api_version,
            prefix,
            user,
        )
        return embeddings[0] if isinstance(text, str) else embeddings
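

# Illustrative sketch (url, key, and model are placeholder assumptions).
# A single string returns one vector; a list returns one vector per item:
#
#     one = generate_embeddings(
#         engine="ollama",
#         model="nomic-embed-text",
#         text="hello world",
#         url="http://localhost:11434",
#         key="",
#     )  # -> list[float]
#     many = generate_embeddings(
#         engine="ollama",
#         model="nomic-embed-text",
#         text=["a", "b"],
#         url="http://localhost:11434",
#         key="",
#     )  # -> list[list[float]]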


import operator
from typing import Optional, Sequence

from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document


class RerankCompressor(BaseDocumentCompressor):
    embedding_function: Any
    top_n: int
    reranking_function: Any
    r_score: float

    class Config:
        extra = "forbid"
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        reranking = self.reranking_function is not None

        if reranking:
            scores = self.reranking_function.predict(
                [(query, doc.page_content) for doc in documents]
            )
        else:
            from sentence_transformers import util

            query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
            document_embedding = self.embedding_function(
                [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
            )
            scores = util.cos_sim(query_embedding, document_embedding)[0]

        docs_with_scores = list(
            zip(documents, scores.tolist() if not isinstance(scores, list) else scores)
        )
        if self.r_score:
            docs_with_scores = [
                (d, s) for d, s in docs_with_scores if s >= self.r_score
            ]

        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
        final_results = []
        for doc, doc_score in result[: self.top_n]:
            metadata = doc.metadata
            metadata["score"] = doc_score
            doc = Document(
                page_content=doc.page_content,
                metadata=metadata,
            )
            final_results.append(doc)
        return final_results
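

# Illustrative sketch (the CrossEncoder model name is an assumption; any
# object exposing .predict(list[tuple[str, str]]) -> scores works, and `docs`
# stands for a Sequence[Document] from retrieval):
#
#     from sentence_transformers import CrossEncoder
#
#     compressor = RerankCompressor(
#         embedding_function=embedding_function,
#         top_n=5,
#         reranking_function=CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2"),
#         r_score=0.0,  # falsy -> no minimum-score filtering
#     )
#     reranked = compressor.compress_documents(docs, "example query")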