elasticsearch.py

from typing import Optional

from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk, scan

from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    ELASTICSEARCH_URL,
    ELASTICSEARCH_CA_CERTS,
    ELASTICSEARCH_API_KEY,
    ELASTICSEARCH_USERNAME,
    ELASTICSEARCH_PASSWORD,
    ELASTICSEARCH_CLOUD_ID,
    ELASTICSEARCH_INDEX_PREFIX,
    SSL_ASSERT_FINGERPRINT,
)


class ElasticsearchClient(VectorDBBase):
    """
    Important:
    To keep the number of indexes small, and because the embedding vector length
    is fixed per model, we do not create one index per file. Instead, the
    collection name is stored as a keyword field on each document, and documents
    are separated into different indexes based on the embedding length.
    """

    def __init__(self):
        self.index_prefix = ELASTICSEARCH_INDEX_PREFIX
        self.client = Elasticsearch(
            hosts=[ELASTICSEARCH_URL],
            ca_certs=ELASTICSEARCH_CA_CERTS,
            api_key=ELASTICSEARCH_API_KEY,
            cloud_id=ELASTICSEARCH_CLOUD_ID,
            basic_auth=(
                (ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
                if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
                else None
            ),
            ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
        )

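    # A minimal connection sketch; the values below are illustrative, not the
    # defaults from open_webui.config:
    #   ELASTICSEARCH_URL=https://localhost:9200
    #   ELASTICSEARCH_USERNAME=elastic, ELASTICSEARCH_PASSWORD=...
    #     (or ELASTICSEARCH_API_KEY=..., or ELASTICSEARCH_CLOUD_ID=...)
    #   ELASTICSEARCH_INDEX_PREFIX=open_webui
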
    # Status: works
    def _get_index_name(self, dimension: int) -> str:
        return f"{self.index_prefix}_d{dimension}"

    # Status: works
    def _scan_result_to_get_result(self, result) -> Optional[GetResult]:
        if not result:
            return None
        ids = []
        documents = []
        metadatas = []
        for hit in result:
            ids.append(hit["_id"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))
        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    # Status: works
    def _result_to_get_result(self, result) -> Optional[GetResult]:
        if not result["hits"]["hits"]:
            return None
        ids = []
        documents = []
        metadatas = []
        for hit in result["hits"]["hits"]:
            ids.append(hit["_id"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))
        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    # Status: works
    def _result_to_search_result(self, result) -> SearchResult:
        ids = []
        distances = []
        documents = []
        metadatas = []
        for hit in result["hits"]["hits"]:
            ids.append(hit["_id"])
            distances.append(hit["_score"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))
        return SearchResult(
            ids=[ids],
            distances=[distances],
            documents=[documents],
            metadatas=[metadatas],
        )

    # Status: works
    def _create_index(self, dimension: int):
        body = {
            "mappings": {
                # Map any dynamically added string field (e.g. under "metadata")
                # to "keyword", so exact-match term filters on metadata work.
                "dynamic_templates": [
                    {
                        "strings": {
                            "match_mapping_type": "string",
                            "mapping": {"type": "keyword"},
                        }
                    }
                ],
                "properties": {
                    "collection": {"type": "keyword"},
                    "id": {"type": "keyword"},
                    "vector": {
                        "type": "dense_vector",
                        "dims": dimension,  # Adjust based on your vector dimensions
                        "index": True,
                        "similarity": "cosine",
                    },
                    "text": {"type": "text"},
                    "metadata": {"type": "object"},
                },
            }
        }
        self.client.indices.create(index=self._get_index_name(dimension), body=body)

    # Status: works
    def _create_batches(self, items: list[VectorItem], batch_size: int = 100):
        for i in range(0, len(items), batch_size):
            yield items[i : i + batch_size]

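    # e.g. list(self._create_batches([1, 2, 3, 4, 5], batch_size=2))
    #      -> [[1, 2], [3, 4], [5]]
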
    # Status: works
    def has_collection(self, collection_name: str) -> bool:
        query_body = {
            "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
        }
        try:
            result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
            return result.body["count"] > 0
        except Exception:
            # The prefix may not match any index yet; treat that as "no collection".
            return False

    def delete_collection(self, collection_name: str):
        query = {"query": {"term": {"collection": collection_name}}}
        self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)

    # Status: works
    def search(
        self, collection_name: str, vectors: list[list[float]], limit: int
    ) -> Optional[SearchResult]:
        query = {
            "size": limit,
            "_source": ["text", "metadata"],
            "query": {
                "script_score": {
                    "query": {
                        "bool": {"filter": [{"term": {"collection": collection_name}}]}
                    },
                    "script": {
                        # script_score must not produce negative scores, and
                        # cosineSimilarity ranges over [-1, 1], so shift to [0, 2].
                        "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
                        "params": {
                            "vector": vectors[0]
                        },  # Assuming a single query vector
                    },
                }
            },
        }
        result = self.client.search(
            index=self._get_index_name(len(vectors[0])), body=query
        )
        return self._result_to_search_result(result)

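    # Since the mapping indexes "vector" with cosine similarity, an approximate
    # kNN search is a possible alternative to the exact script_score scan above
    # (a sketch, assuming Elasticsearch 8.x; "num_candidates" is an illustrative
    # value):
    #
    #     self.client.search(
    #         index=self._get_index_name(len(vectors[0])),
    #         knn={
    #             "field": "vector",
    #             "query_vector": vectors[0],
    #             "k": limit,
    #             "num_candidates": 100,
    #             "filter": {"term": {"collection": collection_name}},
    #         },
    #     )
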
    # Status: only tested halfway
    def query(
        self, collection_name: str, filter: dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        if not self.has_collection(collection_name):
            return None
        query_body = {
            "query": {"bool": {"filter": []}},
            "_source": ["text", "metadata"],
        }
        for field, value in filter.items():
            query_body["query"]["bool"]["filter"].append({"term": {field: value}})
        query_body["query"]["bool"]["filter"].append(
            {"term": {"collection": collection_name}}
        )
        size = limit if limit else 10
        try:
            result = self.client.search(
                index=f"{self.index_prefix}*",
                body=query_body,
                size=size,
            )
            return self._result_to_get_result(result)
        except Exception:
            return None

    # Status: works
    def _has_index(self, dimension: int) -> bool:
        return bool(
            self.client.indices.exists(index=self._get_index_name(dimension=dimension))
        )

    def get_or_create_index(self, dimension: int):
        if not self._has_index(dimension=dimension):
            self._create_index(dimension=dimension)

    # Status: works
    def get(self, collection_name: str) -> Optional[GetResult]:
        # Get all the items in the collection.
        query = {
            "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
            "_source": ["text", "metadata"],
        }
        results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
        return self._scan_result_to_get_result(results)

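    # Note: helpers.scan (used above) pages through the full result set via the
    # scroll API, so get() returns every matching document, not just one page.
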
    # Status: works
    def insert(self, collection_name: str, items: list[VectorItem]):
        if not self._has_index(dimension=len(items[0]["vector"])):
            self._create_index(dimension=len(items[0]["vector"]))
        for batch in self._create_batches(items):
            actions = [
                {
                    "_index": self._get_index_name(dimension=len(item["vector"])),
                    "_id": item["id"],
                    "_source": {
                        "collection": collection_name,
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": item["metadata"],
                    },
                }
                for item in batch
            ]
            bulk(self.client, actions)

    # Upsert documents using the update API with doc_as_upsert=True.
    def upsert(self, collection_name: str, items: list[VectorItem]):
        if not self._has_index(dimension=len(items[0]["vector"])):
            self._create_index(dimension=len(items[0]["vector"]))
        for batch in self._create_batches(items):
            actions = [
                {
                    "_op_type": "update",
                    "_index": self._get_index_name(dimension=len(item["vector"])),
                    "_id": item["id"],
                    "doc": {
                        "collection": collection_name,
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": item["metadata"],
                    },
                    "doc_as_upsert": True,
                }
                for item in batch
            ]
            bulk(self.client, actions)

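    # With doc_as_upsert=True, each update action merges "doc" into the existing
    # document, or indexes "doc" as a new document when the id is not found.
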
    # Delete documents from a collection, selected by document ids or by
    # metadata field values (always scoped to the collection).
    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        query = {
            "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
        }
        # Logic based on ChromaDB: ids take precedence over the metadata filter.
        if ids:
            query["query"]["bool"]["filter"].append({"terms": {"_id": ids}})
        elif filter:
            for field, value in filter.items():
                query["query"]["bool"]["filter"].append(
                    {"term": {f"metadata.{field}": value}}
                )
        self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)

    def reset(self):
        # Drop every index created under the prefix.
        indices = self.client.indices.get(index=f"{self.index_prefix}*")
        for index in indices:
            self.client.indices.delete(index=index)
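
# A minimal end-to-end sketch (assumes a reachable cluster and toy 3-dimensional
# vectors; real embeddings would be e.g. 384- or 1536-dimensional, and items are
# dict-like with "id"/"text"/"vector"/"metadata" keys, as the methods above expect):
#
#     client = ElasticsearchClient()
#     client.insert(
#         "my-collection",
#         [{"id": "1", "text": "hello", "vector": [0.1, 0.2, 0.3],
#           "metadata": {"source": "demo"}}],
#     )
#     result = client.search("my-collection", vectors=[[0.1, 0.2, 0.3]], limit=5)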