# utils.py

import logging
import os
from typing import Optional, Union
import requests
import hashlib
from concurrent.futures import ThreadPoolExecutor
import time
import re
from urllib.parse import quote

from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document

from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT

from open_webui.models.users import UserModel
from open_webui.models.files import Files
from open_webui.models.knowledge import Knowledges
from open_webui.models.chats import Chats
from open_webui.models.notes import Notes

from open_webui.retrieval.vector.main import GetResult
from open_webui.utils.access_control import has_access
from open_webui.utils.misc import get_message_list

from open_webui.retrieval.web.utils import get_web_loader
from open_webui.retrieval.loaders.youtube import YoutubeLoader

from open_webui.env import (
    SRC_LOG_LEVELS,
    OFFLINE_MODE,
    ENABLE_FORWARD_USER_INFO_HEADERS,
)
from open_webui.config import (
    RAG_EMBEDDING_QUERY_PREFIX,
    RAG_EMBEDDING_CONTENT_PREFIX,
    RAG_EMBEDDING_PREFIX_FIELD_NAME,
)

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

from typing import Any

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever


def is_youtube_url(url: str) -> bool:
    youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$"
    return re.match(youtube_regex, url) is not None
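
# Illustrative examples (hypothetical URLs):
#   is_youtube_url("https://youtu.be/abc123")     # -> True
#   is_youtube_url("https://example.com/video")   # -> False (host is not YouTube)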


def get_loader(request, url: str):
    if is_youtube_url(url):
        return YoutubeLoader(
            url,
            language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
            proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
        )
    else:
        return get_web_loader(
            url,
            verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
            requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
            trust_env=request.app.state.config.WEB_SEARCH_TRUST_ENV,
        )


def get_content_from_url(request, url: str) -> tuple[str, list[Document]]:
    loader = get_loader(request, url)
    docs = loader.load()
    content = " ".join([doc.page_content for doc in docs])
    return content, docs
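
# Usage sketch (illustrative; assumes a FastAPI `request` carrying loader config):
#   content, docs = get_content_from_url(request, "https://example.com/article")
#   # `content` is the concatenated page text, `docs` the raw loader Documents.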


class VectorSearchRetriever(BaseRetriever):
    collection_name: Any
    embedding_function: Any
    top_k: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        result = VECTOR_DB_CLIENT.search(
            collection_name=self.collection_name,
            vectors=[self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)],
            limit=self.top_k,
        )

        ids = result.ids[0]
        metadatas = result.metadatas[0]
        documents = result.documents[0]

        results = []
        for idx in range(len(ids)):
            results.append(
                Document(
                    metadata=metadatas[idx],
                    page_content=documents[idx],
                )
            )
        return results
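
# Usage sketch (illustrative; "my-collection" and `embed_fn` are assumptions):
#   retriever = VectorSearchRetriever(
#       collection_name="my-collection", embedding_function=embed_fn, top_k=5
#   )
#   docs = retriever.invoke("what is hybrid search?")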


def query_doc(
    collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
):
    try:
        log.debug(f"query_doc:doc {collection_name}")
        result = VECTOR_DB_CLIENT.search(
            collection_name=collection_name,
            vectors=[query_embedding],
            limit=k,
        )

        if result:
            log.info(f"query_doc:result {result.ids} {result.metadatas}")

        return result
    except Exception as e:
        log.exception(f"Error querying doc {collection_name} with limit {k}: {e}")
        raise e
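
# Usage sketch (illustrative; `embed_fn` stands in for a configured embedding callable):
#   embedding = embed_fn("search query", RAG_EMBEDDING_QUERY_PREFIX)
#   result = query_doc("file-abc", query_embedding=embedding, k=4)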


def get_doc(collection_name: str, user: UserModel = None):
    try:
        log.debug(f"get_doc:doc {collection_name}")
        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)

        if result:
            log.info(f"get_doc:result {result.ids} {result.metadatas}")

        return result
    except Exception as e:
        log.exception(f"Error getting doc {collection_name}: {e}")
        raise e


def query_doc_with_hybrid_search(
    collection_name: str,
    collection_result: GetResult,
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
    hybrid_bm25_weight: float,
) -> dict:
    try:
        # First check if collection_result has the required attributes
        if (
            not collection_result
            or not hasattr(collection_result, "documents")
            or not hasattr(collection_result, "metadatas")
        ):
            log.warning(f"query_doc_with_hybrid_search:no_docs {collection_name}")
            return {"documents": [], "metadatas": [], "distances": []}

        # Now safely check the documents content after confirming attributes exist
        if (
            not collection_result.documents
            or len(collection_result.documents) == 0
            or not collection_result.documents[0]
        ):
            log.warning(f"query_doc_with_hybrid_search:no_docs {collection_name}")
            return {"documents": [], "metadatas": [], "distances": []}

        log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
        bm25_retriever = BM25Retriever.from_texts(
            texts=collection_result.documents[0],
            metadatas=collection_result.metadatas[0],
        )
        bm25_retriever.k = k

        vector_search_retriever = VectorSearchRetriever(
            collection_name=collection_name,
            embedding_function=embedding_function,
            top_k=k,
        )

        if hybrid_bm25_weight <= 0:
            ensemble_retriever = EnsembleRetriever(
                retrievers=[vector_search_retriever], weights=[1.0]
            )
        elif hybrid_bm25_weight >= 1:
            ensemble_retriever = EnsembleRetriever(
                retrievers=[bm25_retriever], weights=[1.0]
            )
        else:
            ensemble_retriever = EnsembleRetriever(
                retrievers=[bm25_retriever, vector_search_retriever],
                weights=[hybrid_bm25_weight, 1.0 - hybrid_bm25_weight],
            )

        compressor = RerankCompressor(
            embedding_function=embedding_function,
            top_n=k_reranker,
            reranking_function=reranking_function,
            r_score=r,
        )

        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=ensemble_retriever
        )

        result = compression_retriever.invoke(query)

        distances = [d.metadata.get("score") for d in result]
        documents = [d.page_content for d in result]
        metadatas = [d.metadata for d in result]

        # retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker
        if k < k_reranker:
            sorted_items = sorted(
                zip(distances, documents, metadatas), key=lambda x: x[0], reverse=True
            )
            sorted_items = sorted_items[:k]
            if sorted_items:
                distances, documents, metadatas = map(list, zip(*sorted_items))
            else:
                distances, documents, metadatas = [], [], []

        result = {
            "distances": [distances],
            "documents": [documents],
            "metadatas": [metadatas],
        }

        log.info(
            "query_doc_with_hybrid_search:result "
            + f'{result["metadatas"]} {result["distances"]}'
        )
        return result
    except Exception as e:
        log.exception(f"Error querying doc {collection_name} with hybrid search: {e}")
        raise e
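
# The returned dict mirrors the vector-DB query shape (illustrative values):
#   {"distances": [[0.92, 0.87]], "documents": [["chunk A", "chunk B"]],
#    "metadatas": [[{...}, {...}]]}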


def merge_get_results(get_results: list[dict]) -> dict:
    # Initialize lists to store combined data
    combined_documents = []
    combined_metadatas = []
    combined_ids = []

    for data in get_results:
        combined_documents.extend(data["documents"][0])
        combined_metadatas.extend(data["metadatas"][0])
        combined_ids.extend(data["ids"][0])

    # Create the output dictionary
    result = {
        "documents": [combined_documents],
        "metadatas": [combined_metadatas],
        "ids": [combined_ids],
    }

    return result


def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict:
    # Initialize lists to store combined data
    combined = dict()  # To store documents with unique document hashes

    for data in query_results:
        if (
            len(data.get("distances", [])) == 0
            or len(data.get("documents", [])) == 0
            or len(data.get("metadatas", [])) == 0
        ):
            continue

        distances = data["distances"][0]
        documents = data["documents"][0]
        metadatas = data["metadatas"][0]

        for distance, document, metadata in zip(distances, documents, metadatas):
            if isinstance(document, str):
                doc_hash = hashlib.sha256(
                    document.encode()
                ).hexdigest()  # Compute a hash for uniqueness

                if doc_hash not in combined.keys():
                    combined[doc_hash] = (distance, document, metadata)
                    continue  # if doc is new, no further comparison is needed

                # if doc is already in, but the new distance is better, update
                if distance > combined[doc_hash][0]:
                    combined[doc_hash] = (distance, document, metadata)

    combined = list(combined.values())

    # Sort the list based on distances
    combined.sort(key=lambda x: x[0], reverse=True)

    # Slice to keep only the top k elements
    sorted_distances, sorted_documents, sorted_metadatas = (
        zip(*combined[:k]) if combined else ([], [], [])
    )

    # Create and return the output dictionary
    return {
        "distances": [list(sorted_distances)],
        "documents": [list(sorted_documents)],
        "metadatas": [list(sorted_metadatas)],
    }
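
# Worked example (illustrative): two result sets sharing one chunk keep only the
# higher-scoring copy, then everything is sorted by score and cut to k:
#   a = {"distances": [[0.9, 0.4]], "documents": [["x", "y"]], "metadatas": [[{}, {}]]}
#   b = {"distances": [[0.7]], "documents": [["x"]], "metadatas": [[{}]]}
#   merge_and_sort_query_results([a, b], k=2)
#   # -> {"distances": [[0.9, 0.4]], "documents": [["x", "y"]], "metadatas": [[{}, {}]]}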


def get_all_items_from_collections(collection_names: list[str]) -> dict:
    results = []

    for collection_name in collection_names:
        if collection_name:
            try:
                result = get_doc(collection_name=collection_name)
                if result is not None:
                    results.append(result.model_dump())
            except Exception as e:
                log.exception(f"Error when querying the collection: {e}")

    return merge_get_results(results)


def query_collection(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
) -> dict:
    results = []
    error = False

    def process_query_collection(collection_name, query_embedding):
        try:
            if collection_name:
                result = query_doc(
                    collection_name=collection_name,
                    k=k,
                    query_embedding=query_embedding,
                )
                if result is not None:
                    return result.model_dump(), None
            return None, None
        except Exception as e:
            log.exception(f"Error when querying the collection: {e}")
            return None, e

    # Generate all query embeddings (in one call)
    query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
    log.debug(
        f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
    )

    with ThreadPoolExecutor() as executor:
        future_results = []
        for query_embedding in query_embeddings:
            for collection_name in collection_names:
                result = executor.submit(
                    process_query_collection, collection_name, query_embedding
                )
                future_results.append(result)
        task_results = [future.result() for future in future_results]

    for result, err in task_results:
        if err is not None:
            error = True
        elif result is not None:
            results.append(result)

    if error and not results:
        log.warning("All collection queries failed. No results returned.")

    return merge_and_sort_query_results(results, k=k)
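
# Usage sketch (illustrative; `embed_fn` must accept a list plus a `prefix` kwarg):
#   merged = query_collection(
#       ["file-abc", "file-def"], ["what changed?"], embed_fn, k=4
#   )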


def query_collection_with_hybrid_search(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
    hybrid_bm25_weight: float,
) -> dict:
    results = []
    error = False

    # Fetch collection data once per collection sequentially
    # Avoid fetching the same data multiple times later
    collection_results = {}
    for collection_name in collection_names:
        try:
            log.debug(
                f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}"
            )
            collection_results[collection_name] = VECTOR_DB_CLIENT.get(
                collection_name=collection_name
            )
        except Exception as e:
            log.exception(f"Failed to fetch collection {collection_name}: {e}")
            collection_results[collection_name] = None

    log.info(
        f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
    )

    def process_query(collection_name, query):
        try:
            result = query_doc_with_hybrid_search(
                collection_name=collection_name,
                collection_result=collection_results[collection_name],
                query=query,
                embedding_function=embedding_function,
                k=k,
                reranking_function=reranking_function,
                k_reranker=k_reranker,
                r=r,
                hybrid_bm25_weight=hybrid_bm25_weight,
            )
            return result, None
        except Exception as e:
            log.exception(f"Error when querying the collection with hybrid_search: {e}")
            return None, e

    # Prepare tasks for all collections and queries
    # Avoid running any tasks for collections that failed to fetch data (have assigned None)
    tasks = [
        (cn, q)
        for cn in collection_names
        if collection_results[cn] is not None
        for q in queries
    ]

    with ThreadPoolExecutor() as executor:
        future_results = [executor.submit(process_query, cn, q) for cn, q in tasks]
        task_results = [future.result() for future in future_results]

    for result, err in task_results:
        if err is not None:
            error = True
        elif result is not None:
            results.append(result)

    if error and not results:
        raise Exception(
            "Hybrid search failed for all collections; the caller falls back to non-hybrid search."
        )

    return merge_and_sort_query_results(results, k=k)


def get_embedding_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    url,
    key,
    embedding_batch_size,
    azure_api_version=None,
):
    if embedding_engine == "":
        return lambda query, prefix=None, user=None: embedding_function.encode(
            query, **({"prompt": prefix} if prefix else {})
        ).tolist()
    elif embedding_engine in ["ollama", "openai", "azure_openai"]:
        func = lambda query, prefix=None, user=None: generate_embeddings(
            engine=embedding_engine,
            model=embedding_model,
            text=query,
            prefix=prefix,
            url=url,
            key=key,
            user=user,
            azure_api_version=azure_api_version,
        )

        def generate_multiple(query, prefix, user, func):
            if isinstance(query, list):
                embeddings = []
                for i in range(0, len(query), embedding_batch_size):
                    batch_embeddings = func(
                        query[i : i + embedding_batch_size],
                        prefix=prefix,
                        user=user,
                    )
                    if isinstance(batch_embeddings, list):
                        embeddings.extend(batch_embeddings)
                return embeddings
            else:
                return func(query, prefix, user)

        return lambda query, prefix=None, user=None: generate_multiple(
            query, prefix, user, func
        )
    else:
        raise ValueError(f"Unknown embedding engine: {embedding_engine}")
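
# Usage sketch (illustrative values; when `embedding_engine` is "" a local
# sentence-transformers model is assumed instead):
#   embed_fn = get_embedding_function(
#       "openai", "text-embedding-3-small", None,
#       "https://api.openai.com/v1", "sk-...", embedding_batch_size=16,
#   )
#   vectors = embed_fn(["chunk one", "chunk two"], prefix=RAG_EMBEDDING_CONTENT_PREFIX)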


def get_reranking_function(reranking_engine, reranking_model, reranking_function):
    if reranking_function is None:
        return None
    if reranking_engine == "external":
        return lambda query, documents, user=None: reranking_function.predict(
            [(query, doc.page_content) for doc in documents], user=user
        )
    else:
        return lambda query, documents, user=None: reranking_function.predict(
            [(query, doc.page_content) for doc in documents]
        )
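
# The returned callable scores (query, document) pairs; only the "external"
# engine forwards the user. Sketch (illustrative; `cross_encoder` is assumed):
#   rerank = get_reranking_function("", "bge-reranker", cross_encoder)
#   scores = rerank("my query", docs)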


def get_sources_from_items(
    request,
    items,
    queries,
    embedding_function,
    k,
    reranking_function,
    k_reranker,
    r,
    hybrid_bm25_weight,
    hybrid_search,
    full_context=False,
    user: Optional[UserModel] = None,
):
    log.debug(
        f"items: {items} {queries} {embedding_function} {reranking_function} {full_context}"
    )

    extracted_collections = []
    query_results = []

    for item in items:
        query_result = None
        collection_names = []

        if item.get("type") == "text":
            # Raw Text
            # Used during temporary chat file uploads or web page and YouTube attachments
            if item.get("context") == "full":
                if item.get("file"):
                    # if item has file data, use it
                    query_result = {
                        "documents": [
                            [item.get("file", {}).get("data", {}).get("content")]
                        ],
                        "metadatas": [[item.get("file", {}).get("meta", {})]],
                    }

            if query_result is None:
                # Fallback
                if item.get("collection_name"):
                    # If item has a collection name, use it
                    collection_names.append(item.get("collection_name"))
                elif item.get("file"):
                    # If item has file data, use it
                    query_result = {
                        "documents": [
                            [item.get("file", {}).get("data", {}).get("content")]
                        ],
                        "metadatas": [[item.get("file", {}).get("meta", {})]],
                    }
                else:
                    # Fallback to item content
                    query_result = {
                        "documents": [[item.get("content")]],
                        "metadatas": [
                            [{"file_id": item.get("id"), "name": item.get("name")}]
                        ],
                    }
        elif item.get("type") == "note":
            # Note Attached
            note = Notes.get_note_by_id(item.get("id"))

            if note and (
                user.role == "admin"
                or note.user_id == user.id
                or has_access(user.id, "read", note.access_control)
            ):
                # User has access to the note
                query_result = {
                    "documents": [[note.data.get("content", {}).get("md", "")]],
                    "metadatas": [[{"file_id": note.id, "name": note.title}]],
                }
        elif item.get("type") == "chat":
            # Chat Attached
            chat = Chats.get_chat_by_id(item.get("id"))

            if chat and (user.role == "admin" or chat.user_id == user.id):
                messages_map = chat.chat.get("history", {}).get("messages", {})
                message_id = chat.chat.get("history", {}).get("currentId")

                if messages_map and message_id:
                    # Reconstruct the message list in order
                    message_list = get_message_list(messages_map, message_id)
                    message_history = "\n".join(
                        [
                            f"#### {m.get('role', 'user').capitalize()}\n{m.get('content')}\n"
                            for m in message_list
                        ]
                    )

                    # User has access to the chat
                    query_result = {
                        "documents": [[message_history]],
                        "metadatas": [[{"file_id": chat.id, "name": chat.title}]],
                    }
        elif item.get("type") == "url":
            content, docs = get_content_from_url(request, item.get("url"))
            if docs:
                query_result = {
                    "documents": [[content]],
                    "metadatas": [[{"url": item.get("url"), "name": item.get("url")}]],
                }
        elif item.get("type") == "file":
            if (
                item.get("context") == "full"
                or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
            ):
                if item.get("file", {}).get("data", {}).get("content", ""):
                    # Manual Full Mode Toggle
                    # Used from the chat file modal; the file content is assumed
                    # to be available at item["file"]["data"]["content"]
                    query_result = {
                        "documents": [
                            [item.get("file", {}).get("data", {}).get("content", "")]
                        ],
                        "metadatas": [
                            [
                                {
                                    "file_id": item.get("id"),
                                    "name": item.get("name"),
                                    **item.get("file")
                                    .get("data", {})
                                    .get("metadata", {}),
                                }
                            ]
                        ],
                    }
                elif item.get("id"):
                    file_object = Files.get_file_by_id(item.get("id"))
                    if file_object:
                        query_result = {
                            "documents": [[file_object.data.get("content", "")]],
                            "metadatas": [
                                [
                                    {
                                        "file_id": item.get("id"),
                                        "name": file_object.filename,
                                        "source": file_object.filename,
                                    }
                                ]
                            ],
                        }
            else:
                # Fallback to collection names
                if item.get("legacy"):
                    collection_names.append(f"{item['id']}")
                else:
                    collection_names.append(f"file-{item['id']}")
        elif item.get("type") == "collection":
            knowledge_base = Knowledges.get_knowledge_by_id(item.get("id"))

            if knowledge_base and (
                user.role == "admin"
                or knowledge_base.user_id == user.id
                or has_access(user.id, "read", knowledge_base.access_control)
            ):
                if (
                    item.get("context") == "full"
                    or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
                ):
                    # Manual Full Mode Toggle for Collection
                    file_ids = knowledge_base.data.get("file_ids", [])

                    documents = []
                    metadatas = []
                    for file_id in file_ids:
                        file_object = Files.get_file_by_id(file_id)

                        if file_object:
                            documents.append(file_object.data.get("content", ""))
                            metadatas.append(
                                {
                                    "file_id": file_id,
                                    "name": file_object.filename,
                                    "source": file_object.filename,
                                }
                            )

                    query_result = {
                        "documents": [documents],
                        "metadatas": [metadatas],
                    }
                else:
                    # Fallback to collection names
                    if item.get("legacy"):
                        collection_names = item.get("collection_names", [])
                    else:
                        collection_names.append(item["id"])
        elif item.get("docs"):
            # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
            query_result = {
                "documents": [[doc.get("content") for doc in item.get("docs")]],
                "metadatas": [[doc.get("metadata") for doc in item.get("docs")]],
            }
        elif item.get("collection_name"):
            # Direct Collection Name
            collection_names.append(item["collection_name"])
        elif item.get("collection_names"):
            # Collection Names List
            collection_names.extend(item["collection_names"])

        # If query_result is None
        # Fallback to collection names and vector search the collections
        if query_result is None and collection_names:
            collection_names = set(collection_names).difference(extracted_collections)
            if not collection_names:
                log.debug(f"skipping {item} as it has already been extracted")
                continue

            try:
                if full_context:
                    query_result = get_all_items_from_collections(collection_names)
                else:
                    query_result = None  # Initialize to None
                    if hybrid_search:
                        try:
                            query_result = query_collection_with_hybrid_search(
                                collection_names=collection_names,
                                queries=queries,
                                embedding_function=embedding_function,
                                k=k,
                                reranking_function=reranking_function,
                                k_reranker=k_reranker,
                                r=r,
                                hybrid_bm25_weight=hybrid_bm25_weight,
                            )
                        except Exception as e:
                            log.debug(
                                "Error when using hybrid search, using non hybrid search as fallback."
                            )

                    # fallback to non-hybrid search
                    if not hybrid_search or query_result is None:
                        query_result = query_collection(
                            collection_names=collection_names,
                            queries=queries,
                            embedding_function=embedding_function,
                            k=k,
                        )
            except Exception as e:
                log.exception(e)

            extracted_collections.extend(collection_names)

        if query_result:
            if "data" in item:
                del item["data"]
            query_results.append({**query_result, "file": item})

    sources = []
    for query_result in query_results:
        try:
            if "documents" in query_result:
                if "metadatas" in query_result:
                    source = {
                        "source": query_result["file"],
                        "document": query_result["documents"][0],
                        "metadata": query_result["metadatas"][0],
                    }

                    if "distances" in query_result and query_result["distances"]:
                        source["distances"] = query_result["distances"][0]

                    sources.append(source)
        except Exception as e:
            log.exception(e)

    return sources
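
# Each returned source has the shape (illustrative):
#   {"source": <original item>, "document": [...chunks...],
#    "metadata": [...per-chunk metadata...], "distances": [...scores, optional...]}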


def get_model_path(model: str, update_model: bool = False):
    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")

    local_files_only = not update_model
    if OFFLINE_MODE:
        local_files_only = True

    snapshot_kwargs = {
        "cache_dir": cache_dir,
        "local_files_only": local_files_only,
    }

    log.debug(f"model: {model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # Inspiration from upstream sentence_transformers
    if (
        os.path.exists(model)
        or ("\\" in model or model.count("/") > 1)
        and local_files_only
    ):
        # If fully qualified path exists, return input, else set repo_id
        return model
    elif "/" not in model:
        # Set valid repo_id for model short-name
        model = "sentence-transformers" + "/" + model

    snapshot_kwargs["repo_id"] = model

    # Attempt to query the huggingface_hub library to determine the local path and/or to update
    try:
        model_repo_path = snapshot_download(**snapshot_kwargs)
        log.debug(f"model_repo_path: {model_repo_path}")
        return model_repo_path
    except Exception as e:
        log.exception(f"Cannot determine model snapshot path: {e}")
        return model
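
# Usage sketch (illustrative): a bare short name is resolved under the
# sentence-transformers namespace:
#   get_model_path("all-MiniLM-L6-v2")
#   # locates/downloads "sentence-transformers/all-MiniLM-L6-v2" and returns
#   # the local snapshot path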


def generate_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str = "https://api.openai.com/v1",
    key: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    try:
        log.debug(
            f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
        )
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        r = requests.post(
            f"{url}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
                **(
                    {
                        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
                        "X-OpenWebUI-User-Id": user.id,
                        "X-OpenWebUI-User-Email": user.email,
                        "X-OpenWebUI-User-Role": user.role,
                    }
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
                    else {}
                ),
            },
            json=json_data,
        )
        r.raise_for_status()
        data = r.json()
        if "data" in data:
            return [elem["embedding"] for elem in data["data"]]
        else:
            raise Exception("Something went wrong :/")
    except Exception as e:
        log.exception(f"Error generating openai batch embeddings: {e}")
        return None
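
# Usage sketch (illustrative endpoint/key):
#   vectors = generate_openai_batch_embeddings(
#       "text-embedding-3-small", ["hello", "world"], key="sk-..."
#   )
#   # -> [[...], [...]] on success, or None on failure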


def generate_azure_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = "",
    version: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    try:
        log.debug(
            f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
        )
        json_data = {"input": texts}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"

        for _ in range(5):
            r = requests.post(
                url,
                headers={
                    "Content-Type": "application/json",
                    "api-key": key,
                    **(
                        {
                            "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
                            "X-OpenWebUI-User-Id": user.id,
                            "X-OpenWebUI-User-Email": user.email,
                            "X-OpenWebUI-User-Role": user.role,
                        }
                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
                        else {}
                    ),
                },
                json=json_data,
            )
            if r.status_code == 429:
                retry = float(r.headers.get("Retry-After", "1"))
                time.sleep(retry)
                continue
            r.raise_for_status()
            data = r.json()
            if "data" in data:
                return [elem["embedding"] for elem in data["data"]]
            else:
                raise Exception("Something went wrong :/")
        return None
    except Exception as e:
        log.exception(f"Error generating azure openai batch embeddings: {e}")
        return None


def generate_ollama_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    try:
        log.debug(
            f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
        )
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        r = requests.post(
            f"{url}/api/embed",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
                **(
                    {
                        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
                        "X-OpenWebUI-User-Id": user.id,
                        "X-OpenWebUI-User-Email": user.email,
                        "X-OpenWebUI-User-Role": user.role,
                    }
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
                    else {}
                ),
            },
            json=json_data,
        )
        r.raise_for_status()
        data = r.json()

        if "embeddings" in data:
            return data["embeddings"]
        else:
            raise Exception("Something went wrong :/")
    except Exception as e:
        log.exception(f"Error generating ollama batch embeddings: {e}")
        return None


def generate_embeddings(
    engine: str,
    model: str,
    text: Union[str, list[str]],
    prefix: Union[str, None] = None,
    **kwargs,
):
    url = kwargs.get("url", "")
    key = kwargs.get("key", "")
    user = kwargs.get("user")

    if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
        if isinstance(text, list):
            text = [f"{prefix}{text_element}" for text_element in text]
        else:
            text = f"{prefix}{text}"

    if engine == "ollama":
        embeddings = generate_ollama_batch_embeddings(
            **{
                "model": model,
                "texts": text if isinstance(text, list) else [text],
                "url": url,
                "key": key,
                "prefix": prefix,
                "user": user,
            }
        )
        return embeddings[0] if isinstance(text, str) else embeddings
    elif engine == "openai":
        embeddings = generate_openai_batch_embeddings(
            model, text if isinstance(text, list) else [text], url, key, prefix, user
        )
        return embeddings[0] if isinstance(text, str) else embeddings
    elif engine == "azure_openai":
        azure_api_version = kwargs.get("azure_api_version", "")
        embeddings = generate_azure_openai_batch_embeddings(
            model,
            text if isinstance(text, list) else [text],
            url,
            key,
            azure_api_version,
            prefix,
            user,
        )
        return embeddings[0] if isinstance(text, str) else embeddings
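
# Dispatch sketch (illustrative endpoints, keys, and model names): a single
# string yields one vector, a list yields a list of vectors:
#   vec = generate_embeddings("ollama", "nomic-embed-text", "hello",
#                             url="http://localhost:11434", key="")
#   vecs = generate_embeddings("openai", "text-embedding-3-small", ["a", "b"],
#                              url="https://api.openai.com/v1", key="sk-...")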


import operator
from typing import Optional, Sequence

from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document


class RerankCompressor(BaseDocumentCompressor):
    embedding_function: Any
    top_n: int
    reranking_function: Any
    r_score: float

    class Config:
        extra = "forbid"
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        reranking = self.reranking_function is not None

        scores = None
        if reranking:
            scores = self.reranking_function(query, documents)
        else:
            from sentence_transformers import util

            query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
            document_embedding = self.embedding_function(
                [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
            )
            scores = util.cos_sim(query_embedding, document_embedding)[0]

        if scores is not None:
            docs_with_scores = list(
                zip(
                    documents,
                    scores.tolist() if not isinstance(scores, list) else scores,
                )
            )
            if self.r_score:
                docs_with_scores = [
                    (d, s) for d, s in docs_with_scores if s >= self.r_score
                ]

            result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
            final_results = []
            for doc, doc_score in result[: self.top_n]:
                metadata = doc.metadata
                metadata["score"] = doc_score
                doc = Document(
                    page_content=doc.page_content,
                    metadata=metadata,
                )
                final_results.append(doc)
            return final_results
        else:
            log.warning(
                "No valid scores found, check your reranking function. Returning original documents."
            )
            return documents
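
# Usage sketch (illustrative; `embed_fn` and `docs` are assumptions):
#   compressor = RerankCompressor(
#       embedding_function=embed_fn, top_n=3, reranking_function=None, r_score=0.0
#   )
#   top_docs = compressor.compress_documents(docs, "my query")
#   # with reranking_function=None this falls back to cosine similarity
#   # between the query and document embeddings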