elasticsearch.py

from elasticsearch import Elasticsearch, BadRequestError
from typing import Optional
import ssl
from elasticsearch.helpers import bulk, scan
from open_webui.retrieval.vector.utils import stringify_metadata
from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    ELASTICSEARCH_URL,
    ELASTICSEARCH_CA_CERTS,
    ELASTICSEARCH_API_KEY,
    ELASTICSEARCH_USERNAME,
    ELASTICSEARCH_PASSWORD,
    ELASTICSEARCH_CLOUD_ID,
    ELASTICSEARCH_INDEX_PREFIX,
    SSL_ASSERT_FINGERPRINT,
)


class ElasticsearchClient(VectorDBBase):
    """
    Important:
    To reduce the number of indexes, and because the embedding vector length
    is fixed per model, we avoid creating an index per file. Instead, each
    file's content is stored as documents with a text field in a shared index,
    and documents are separated into different indexes based on their
    embedding length.
    """

    def __init__(self):
        self.index_prefix = ELASTICSEARCH_INDEX_PREFIX
        # Connection settings come from open_webui.config; basic auth is only
        # sent when both a username and a password are configured.
        self.client = Elasticsearch(
            hosts=[ELASTICSEARCH_URL],
            ca_certs=ELASTICSEARCH_CA_CERTS,
            api_key=ELASTICSEARCH_API_KEY,
            cloud_id=ELASTICSEARCH_CLOUD_ID,
            basic_auth=(
                (ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
                if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
                else None
            ),
            ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
        )

    # Status: works
    def _get_index_name(self, dimension: int) -> str:
        return f"{self.index_prefix}_d{dimension}"

    # Status: works
    def _scan_result_to_get_result(self, result) -> Optional[GetResult]:
        if not result:
            return None
        ids = []
        documents = []
        metadatas = []
        for hit in result:
            ids.append(hit["_id"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))
        # Results are wrapped in an extra list to match the Chroma-style
        # batched result shape.
        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    # Status: works
    def _result_to_get_result(self, result) -> Optional[GetResult]:
        if not result["hits"]["hits"]:
            return None
        ids = []
        documents = []
        metadatas = []
        for hit in result["hits"]["hits"]:
            ids.append(hit["_id"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))
        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    # Status: works
    def _result_to_search_result(self, result) -> SearchResult:
        ids = []
        distances = []
        documents = []
        metadatas = []
        for hit in result["hits"]["hits"]:
            ids.append(hit["_id"])
            # Note: _score here is a similarity (higher is better), not a
            # distance; see the script_score query in search().
            distances.append(hit["_score"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))
        return SearchResult(
            ids=[ids],
            distances=[distances],
            documents=[documents],
            metadatas=[metadatas],
        )

    # Status: works
    def _create_index(self, dimension: int):
        body = {
            "mappings": {
                "dynamic_templates": [
                    {
                        "strings": {
                            "match_mapping_type": "string",
                            "mapping": {"type": "keyword"},
                        }
                    }
                ],
                "properties": {
                    "collection": {"type": "keyword"},
                    "id": {"type": "keyword"},
                    "vector": {
                        "type": "dense_vector",
                        "dims": dimension,  # One index per embedding dimension
                        "index": True,
                        "similarity": "cosine",
                    },
                    "text": {"type": "text"},
                    "metadata": {"type": "object"},
                },
            }
        }
        self.client.indices.create(index=self._get_index_name(dimension), body=body)
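
    # The dynamic template above maps dynamically added string fields
    # (including strings nested under "metadata") to "keyword", so the
    # exact-term filters built in delete() ({"term": {"metadata.<field>":
    # value}}) match literally rather than through full-text analysis.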

    # Status: works
    def _create_batches(self, items: list[VectorItem], batch_size=100):
        for i in range(0, len(items), batch_size):
            # Slicing clamps at the end of the list, so no min() is needed.
            yield items[i : i + batch_size]

    # Status: works
    def has_collection(self, collection_name) -> bool:
        query_body = {
            "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
        }
        try:
            result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
            return result.body["count"] > 0
        except Exception:
            # Treat any lookup failure as "collection not found" so the
            # declared bool return type holds.
            return False

    def delete_collection(self, collection_name: str):
        # The dimension (and thus the exact index) is unknown here, so the
        # delete-by-query runs across every index under the prefix.
        query = {"query": {"term": {"collection": collection_name}}}
        self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)

    # Status: works
    def search(
        self, collection_name: str, vectors: list[list[float]], limit: int
    ) -> Optional[SearchResult]:
        query = {
            "size": limit,
            "_source": ["text", "metadata"],
            "query": {
                "script_score": {
                    "query": {
                        "bool": {"filter": [{"term": {"collection": collection_name}}]}
                    },
                    "script": {
                        # script_score must not produce negative scores, so the
                        # cosine similarity (range [-1, 1]) is shifted by 1.0.
                        "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
                        "params": {
                            "vector": vectors[0]
                        },  # Assuming a single query vector
                    },
                }
            },
        }
        result = self.client.search(
            index=self._get_index_name(len(vectors[0])), body=query
        )
        return self._result_to_search_result(result)

    # Status: only tested halfway
    def query(
        self, collection_name: str, filter: dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        if not self.has_collection(collection_name):
            return None
        query_body = {
            "query": {"bool": {"filter": []}},
            "_source": ["text", "metadata"],
        }
        for field, value in filter.items():
            query_body["query"]["bool"]["filter"].append({"term": {field: value}})
        query_body["query"]["bool"]["filter"].append(
            {"term": {"collection": collection_name}}
        )
        size = limit if limit else 10
        try:
            result = self.client.search(
                index=f"{self.index_prefix}*",
                body=query_body,
                size=size,
            )
            return self._result_to_get_result(result)
        except Exception:
            return None

    # Status: works
    def _has_index(self, dimension: int) -> bool:
        return bool(
            self.client.indices.exists(index=self._get_index_name(dimension=dimension))
        )

    def get_or_create_index(self, dimension: int):
        if not self._has_index(dimension=dimension):
            self._create_index(dimension=dimension)

    # Status: works
    def get(self, collection_name: str) -> Optional[GetResult]:
        # Get all the items in the collection; scan() streams every hit via
        # the scroll API, so no explicit size limit is needed.
        query = {
            "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
            "_source": ["text", "metadata"],
        }
        results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
        return self._scan_result_to_get_result(results)

    # Status: works
    def insert(self, collection_name: str, items: list[VectorItem]):
        dimension = len(items[0]["vector"])
        if not self._has_index(dimension=dimension):
            self._create_index(dimension=dimension)
        for batch in self._create_batches(items):
            actions = [
                {
                    "_index": self._get_index_name(dimension=dimension),
                    "_id": item["id"],
                    "_source": {
                        "collection": collection_name,
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": stringify_metadata(item["metadata"]),
                    },
                }
                for item in batch
            ]
            bulk(self.client, actions)

    # Upsert documents using the bulk update API with doc_as_upsert=True.
    def upsert(self, collection_name: str, items: list[VectorItem]):
        if not self._has_index(dimension=len(items[0]["vector"])):
            self._create_index(dimension=len(items[0]["vector"]))
        for batch in self._create_batches(items):
            actions = [
                {
                    "_op_type": "update",
                    "_index": self._get_index_name(dimension=len(item["vector"])),
                    "_id": item["id"],
                    "doc": {
                        "collection": collection_name,
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": stringify_metadata(item["metadata"]),
                    },
                    "doc_as_upsert": True,
                }
                for item in batch
            ]
            bulk(self.client, actions)
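
    # Note the contrast with insert(): bulk's default "index" op replaces the
    # whole document for an existing _id, while "update" with doc_as_upsert
    # merges the given fields into an existing document, or creates it when
    # the _id is absent.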

    # Delete documents from a collection, filtered by document IDs or by
    # metadata fields.
    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        query = {
            "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
        }
        # Filtering logic based on ChromaDB: IDs take precedence over the
        # metadata filter.
        if ids:
            query["query"]["bool"]["filter"].append({"terms": {"_id": ids}})
        elif filter:
            for field, value in filter.items():
                query["query"]["bool"]["filter"].append(
                    {"term": {f"metadata.{field}": value}}
                )
        self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)

    def reset(self):
        # Drop every index under the configured prefix.
        indices = self.client.indices.get(index=f"{self.index_prefix}*")
        for index in indices:
            self.client.indices.delete(index=index)
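

# A minimal usage sketch, not part of the module's public surface: it assumes
# a reachable Elasticsearch instance configured through the open_webui
# settings imported above, and uses a hypothetical 4-dimensional embedding and
# collection name purely for illustration. Items are plain dicts because the
# methods above access them dict-style (item["id"], item["vector"], ...).
if __name__ == "__main__":
    client = ElasticsearchClient()
    items = [
        {
            "id": "doc-1",
            "text": "hello world",
            "vector": [0.1, 0.2, 0.3, 0.4],
            "metadata": {"source": "example.txt"},
        }
    ]
    client.insert("demo-collection", items)
    # Search with the same vector; expect "doc-1" as the top hit. (A freshly
    # indexed document may need an index refresh before it becomes searchable.)
    result = client.search("demo-collection", vectors=[[0.1, 0.2, 0.3, 0.4]], limit=1)
    print(result)
    client.delete_collection("demo-collection")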