opensearch.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  import logging
  from typing import Optional

  from opensearchpy import OpenSearch
  from opensearchpy.helpers import bulk

  from open_webui.retrieval.vector.utils import stringify_metadata
  from open_webui.retrieval.vector.main import (
      VectorDBBase,
      VectorItem,
      SearchResult,
      GetResult,
  )
  from open_webui.config import (
      OPENSEARCH_URI,
      OPENSEARCH_SSL,
      OPENSEARCH_CERT_VERIFY,
      OPENSEARCH_USERNAME,
      OPENSEARCH_PASSWORD,
  )
  class OpenSearchClient(VectorDBBase):
      """Open WebUI vector-store backend backed by OpenSearch.

      Every collection maps to its own index named
      ``<index_prefix>_<collection_name>`` (``open_webui_<collection_name>`` by
      default). Each stored document holds the embedding under ``vector``, the
      chunk text under ``text`` and the stringified metadata under ``metadata``.
      """

      # Size used when a read should return "everything". 10000 is OpenSearch's
      # default ``index.max_result_window``; larger requests are rejected unless
      # that index setting is raised, so bigger collections are truncated here.
      _MAX_RESULT_SIZE = 10000

      # Logger for the best-effort read paths that return None instead of raising.
      _log = logging.getLogger(__name__)

      def __init__(self):
          self.index_prefix = "open_webui"
          self.client = OpenSearch(
              hosts=[OPENSEARCH_URI],
              use_ssl=OPENSEARCH_SSL,
              verify_certs=OPENSEARCH_CERT_VERIFY,
              http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
          )

      def _get_index_name(self, collection_name: str) -> str:
          """Return the name of the OpenSearch index backing ``collection_name``."""
          return f"{self.index_prefix}_{collection_name}"

      @staticmethod
      def _build_filter_query(filters: dict) -> dict:
          """Build a bool query whose filter clauses must all match.

          Each ``{field: value}`` pair becomes a ``term`` clause on
          ``metadata.<field>.keyword``, i.e. an exact match on the keyword
          sub-field rather than on analyzed text.
          """
          # NOTE(review): assumes dynamic mapping created a ``.keyword`` sub-field
          # for the metadata value (typical for strings) — confirm for other types.
          return {
              "bool": {
                  "filter": [
                      {"term": {"metadata." + str(field) + ".keyword": value}}
                      for field, value in filters.items()
                  ]
              }
          }

      @staticmethod
      def _item_to_source(item: VectorItem) -> dict:
          """Return the document body stored for one item (vector, text, metadata)."""
          return {
              "vector": item["vector"],
              "text": item["text"],
              "metadata": stringify_metadata(item["metadata"]),
          }

      def _result_to_get_result(self, result) -> Optional[GetResult]:
          """Convert a raw search response into a ``GetResult``.

          Returns None when the response contains no hits. Each column is
          wrapped in a one-element outer list (a single result batch).
          """
          hits = result["hits"]["hits"]
          if not hits:
              return None
          return GetResult(
              ids=[[hit["_id"] for hit in hits]],
              documents=[[hit["_source"].get("text") for hit in hits]],
              metadatas=[[hit["_source"].get("metadata") for hit in hits]],
          )

      def _result_to_search_result(self, result) -> Optional[SearchResult]:
          """Convert a raw search response into a ``SearchResult``.

          Each hit's ``_score`` is reported as its distance. Returns None when the
          response contains no hits. Columns are wrapped in a one-element outer
          list (a single result batch).
          """
          hits = result["hits"]["hits"]
          if not hits:
              return None
          return SearchResult(
              ids=[[hit["_id"] for hit in hits]],
              distances=[[hit["_score"] for hit in hits]],
              documents=[[hit["_source"].get("text") for hit in hits]],
              metadatas=[[hit["_source"].get("metadata") for hit in hits]],
          )

      def _create_index(self, collection_name: str, dimension: int):
          """Create the k-NN enabled index for ``collection_name``.

          Args:
              collection_name: Logical collection name (the index prefix is added).
              dimension: Length of the vectors that will be stored.
          """
          body = {
              "settings": {"index": {"knn": True}},
              "mappings": {
                  "properties": {
                      "id": {"type": "keyword"},
                      "vector": {
                          "type": "knn_vector",
                          "dimension": dimension,  # Adjust based on your vector dimensions
                          # NOTE(review): "index" and "similarity" are not among the
                          # documented knn_vector mapping parameters — confirm the
                          # target cluster accepts them.
                          "index": True,
                          "similarity": "faiss",
                          "method": {
                              "name": "hnsw",
                              "space_type": "innerproduct",  # Use inner product to approximate cosine similarity
                              "engine": "faiss",
                              "parameters": {
                                  "ef_construction": 128,
                                  "m": 16,
                              },
                          },
                      },
                      "text": {"type": "text"},
                      "metadata": {"type": "object"},
                  }
              },
          }
          self.client.indices.create(
              index=self._get_index_name(collection_name), body=body
          )

      def _create_batches(self, items: list[VectorItem], batch_size=100):
          """Yield consecutive slices of ``items`` of at most ``batch_size`` elements."""
          for i in range(0, len(items), batch_size):
              yield items[i : i + batch_size]

      def has_collection(self, collection_name: str) -> bool:
          """Return True if the index backing ``collection_name`` exists."""
          # has_collection here means has index.
          # We are simply adapting to the norms of the other DBs.
          return self.client.indices.exists(index=self._get_index_name(collection_name))

      def delete_collection(self, collection_name: str):
          """Delete the index backing ``collection_name``."""
          # delete_collection here means delete index.
          # We are simply adapting to the norms of the other DBs.
          self.client.indices.delete(index=self._get_index_name(collection_name))

      def search(
          self, collection_name: str, vectors: list[list[float | int]], limit: int
      ) -> Optional[SearchResult]:
          """Return the ``limit`` documents most cosine-similar to ``vectors[0]``.

          Only the first query vector is used. Scores are
          ``(cosineSimilarity + 1) / 2``, which maps cosine's [-1, 1] onto [0, 1].

          Returns:
              A ``SearchResult``, or None when the collection does not exist, no
              vector is given, nothing matches, or the request fails (logged).
          """
          if not vectors:
              return None
          try:
              if not self.has_collection(collection_name):
                  return None
              # script_score over match_all scores every document with the script
              # (exact search); the HNSW settings from _create_index are not used
              # by this query.
              query = {
                  "size": limit,
                  "_source": ["text", "metadata"],
                  "query": {
                      "script_score": {
                          "query": {"match_all": {}},
                          "script": {
                              "source": "(cosineSimilarity(params.query_value, doc[params.field]) + 1.0) / 2.0",
                              "params": {
                                  "field": "vector",
                                  "query_value": vectors[0],
                              },  # Assuming single query vector
                          },
                      }
                  },
              }
              result = self.client.search(
                  index=self._get_index_name(collection_name), body=query
              )
              return self._result_to_search_result(result)
          except Exception:
              self._log.exception(
                  "OpenSearch search failed for collection %s", collection_name
              )
              return None

      def query(
          self, collection_name: str, filter: dict, limit: Optional[int] = None
      ) -> Optional[GetResult]:
          """Return documents whose metadata exactly matches every ``filter`` pair.

          Args:
              collection_name: Logical collection name.
              filter: ``{metadata_field: value}`` pairs, all of which must match.
              limit: Maximum number of hits; falsy means up to ``_MAX_RESULT_SIZE``.

          Returns:
              A ``GetResult``, or None when the collection does not exist, nothing
              matches, or the request fails (logged).
          """
          if not self.has_collection(collection_name):
              return None
          query_body = {
              "query": self._build_filter_query(filter),
              "_source": ["text", "metadata"],
          }
          size = limit if limit else self._MAX_RESULT_SIZE
          try:
              result = self.client.search(
                  index=self._get_index_name(collection_name),
                  body=query_body,
                  size=size,
              )
              return self._result_to_get_result(result)
          except Exception:
              self._log.exception(
                  "OpenSearch query failed for collection %s", collection_name
              )
              return None

      def _create_index_if_not_exists(self, collection_name: str, dimension: int):
          """Create the collection's index unless it already exists."""
          if not self.has_collection(collection_name):
              self._create_index(collection_name, dimension)

      def get(self, collection_name: str) -> Optional[GetResult]:
          """Return the documents of a collection (up to ``_MAX_RESULT_SIZE``).

          Returns None when the collection does not exist or is empty.
          """
          if not self.has_collection(collection_name):
              return None
          # Without an explicit size OpenSearch returns only its default page of
          # 10 hits, which would silently truncate the collection.
          query = {
              "size": self._MAX_RESULT_SIZE,
              "query": {"match_all": {}},
              "_source": ["text", "metadata"],
          }
          result = self.client.search(
              index=self._get_index_name(collection_name), body=query
          )
          return self._result_to_get_result(result)

      def insert(self, collection_name: str, items: list[VectorItem]):
          """Index ``items``, creating the collection if needed.

          The index dimension is taken from the first item's vector. Documents
          are sent with the bulk helper in batches of 100, then the index is
          refreshed so the writes become searchable. An empty ``items`` is a no-op.
          """
          if not items:
              return
          self._create_index_if_not_exists(
              collection_name=collection_name, dimension=len(items[0]["vector"])
          )
          index_name = self._get_index_name(collection_name)
          for batch in self._create_batches(items):
              actions = [
                  {
                      "_op_type": "index",
                      "_index": index_name,
                      "_id": item["id"],
                      "_source": self._item_to_source(item),
                  }
                  for item in batch
              ]
              bulk(self.client, actions)
          self.client.indices.refresh(index_name)

      def upsert(self, collection_name: str, items: list[VectorItem]):
          """Update-or-insert ``items``, creating the collection if needed.

          Uses bulk ``update`` actions with ``doc_as_upsert`` so missing ids are
          created and existing ones get their fields replaced by the new values.
          The index is refreshed afterwards. An empty ``items`` is a no-op.
          """
          if not items:
              return
          self._create_index_if_not_exists(
              collection_name=collection_name, dimension=len(items[0]["vector"])
          )
          index_name = self._get_index_name(collection_name)
          for batch in self._create_batches(items):
              actions = [
                  {
                      "_op_type": "update",
                      "_index": index_name,
                      "_id": item["id"],
                      "doc": self._item_to_source(item),
                      "doc_as_upsert": True,
                  }
                  for item in batch
              ]
              bulk(self.client, actions)
          self.client.indices.refresh(index_name)

      def delete(
          self,
          collection_name: str,
          ids: Optional[list[str]] = None,
          filter: Optional[dict] = None,
      ):
          """Delete documents by id, or else by exact metadata filter.

          ``ids`` takes precedence: ``filter`` is only used when ``ids`` is empty
          or None. When neither is given nothing is deleted. The index is
          refreshed in every case.
          """
          index_name = self._get_index_name(collection_name)
          if ids:
              actions = [
                  {
                      "_op_type": "delete",
                      "_index": index_name,
                      "_id": doc_id,
                  }
                  for doc_id in ids
              ]
              bulk(self.client, actions)
          elif filter:
              self.client.delete_by_query(
                  index=index_name,
                  body={"query": self._build_filter_query(filter)},
              )
          self.client.indices.refresh(index_name)

      def reset(self):
          """Delete every index whose name matches ``<index_prefix>_*``."""
          indices = self.client.indices.get(index=f"{self.index_prefix}_*")
          for index in indices:
              self.client.indices.delete(index=index)