# opensearch.py

import logging
from typing import Optional

from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk

from open_webui.config import (
    OPENSEARCH_CERT_VERIFY,
    OPENSEARCH_PASSWORD,
    OPENSEARCH_SSL,
    OPENSEARCH_URI,
    OPENSEARCH_USERNAME,
)
from open_webui.retrieval.vector.main import (
    GetResult,
    SearchResult,
    VectorDBBase,
    VectorItem,
)
  17. class OpenSearchClient(VectorDBBase):
  18. def __init__(self):
  19. self.index_prefix = "open_webui"
  20. self.client = OpenSearch(
  21. hosts=[OPENSEARCH_URI],
  22. use_ssl=OPENSEARCH_SSL,
  23. verify_certs=OPENSEARCH_CERT_VERIFY,
  24. http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
  25. )
  26. def _get_index_name(self, collection_name: str) -> str:
  27. return f"{self.index_prefix}_{collection_name}"
  28. def _result_to_get_result(self, result) -> GetResult:
  29. if not result["hits"]["hits"]:
  30. return None
  31. ids = []
  32. documents = []
  33. metadatas = []
  34. for hit in result["hits"]["hits"]:
  35. ids.append(hit["_id"])
  36. documents.append(hit["_source"].get("text"))
  37. metadatas.append(hit["_source"].get("metadata"))
  38. return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
  39. def _result_to_search_result(self, result) -> SearchResult:
  40. if not result["hits"]["hits"]:
  41. return None
  42. ids = []
  43. distances = []
  44. documents = []
  45. metadatas = []
  46. for hit in result["hits"]["hits"]:
  47. ids.append(hit["_id"])
  48. distances.append(hit["_score"])
  49. documents.append(hit["_source"].get("text"))
  50. metadatas.append(hit["_source"].get("metadata"))
  51. return SearchResult(
  52. ids=[ids],
  53. distances=[distances],
  54. documents=[documents],
  55. metadatas=[metadatas],
  56. )
  57. def _create_index(self, collection_name: str, dimension: int):
  58. body = {
  59. "settings": {"index": {"knn": True}},
  60. "mappings": {
  61. "properties": {
  62. "id": {"type": "keyword"},
  63. "vector": {
  64. "type": "knn_vector",
  65. "dimension": dimension, # Adjust based on your vector dimensions
  66. "index": True,
  67. "similarity": "faiss",
  68. "method": {
  69. "name": "hnsw",
  70. "space_type": "innerproduct", # Use inner product to approximate cosine similarity
  71. "engine": "faiss",
  72. "parameters": {
  73. "ef_construction": 128,
  74. "m": 16,
  75. },
  76. },
  77. },
  78. "text": {"type": "text"},
  79. "metadata": {"type": "object"},
  80. }
  81. },
  82. }
  83. self.client.indices.create(
  84. index=self._get_index_name(collection_name), body=body
  85. )
  86. def _create_batches(self, items: list[VectorItem], batch_size=100):
  87. for i in range(0, len(items), batch_size):
  88. yield items[i : i + batch_size]
  89. def has_collection(self, collection_name: str) -> bool:
  90. # has_collection here means has index.
  91. # We are simply adapting to the norms of the other DBs.
  92. return self.client.indices.exists(index=self._get_index_name(collection_name))
  93. def delete_collection(self, collection_name: str):
  94. # delete_collection here means delete index.
  95. # We are simply adapting to the norms of the other DBs.
  96. self.client.indices.delete(index=self._get_index_name(collection_name))
  97. def search(
  98. self, collection_name: str, vectors: list[list[float | int]], limit: int
  99. ) -> Optional[SearchResult]:
  100. try:
  101. if not self.has_collection(collection_name):
  102. return None
  103. query = {
  104. "size": limit,
  105. "_source": ["text", "metadata"],
  106. "query": {
  107. "script_score": {
  108. "query": {"match_all": {}},
  109. "script": {
  110. "source": "(cosineSimilarity(params.query_value, doc[params.field]) + 1.0) / 2.0",
  111. "params": {
  112. "field": "vector",
  113. "query_value": vectors[0],
  114. }, # Assuming single query vector
  115. },
  116. }
  117. },
  118. }
  119. result = self.client.search(
  120. index=self._get_index_name(collection_name), body=query
  121. )
  122. return self._result_to_search_result(result)
  123. except Exception as e:
  124. return None
  125. def query(
  126. self, collection_name: str, filter: dict, limit: Optional[int] = None
  127. ) -> Optional[GetResult]:
  128. if not self.has_collection(collection_name):
  129. return None
  130. query_body = {
  131. "query": {"bool": {"filter": []}},
  132. "_source": ["text", "metadata"],
  133. }
  134. for field, value in filter.items():
  135. query_body["query"]["bool"]["filter"].append(
  136. {"match": {"metadata." + str(field): value}}
  137. )
  138. size = limit if limit else 10
  139. try:
  140. result = self.client.search(
  141. index=self._get_index_name(collection_name),
  142. body=query_body,
  143. size=size,
  144. )
  145. return self._result_to_get_result(result)
  146. except Exception as e:
  147. return None
  148. def _create_index_if_not_exists(self, collection_name: str, dimension: int):
  149. if not self.has_collection(collection_name):
  150. self._create_index(collection_name, dimension)
  151. def get(self, collection_name: str) -> Optional[GetResult]:
  152. query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
  153. result = self.client.search(
  154. index=self._get_index_name(collection_name), body=query
  155. )
  156. return self._result_to_get_result(result)
  157. def insert(self, collection_name: str, items: list[VectorItem]):
  158. self._create_index_if_not_exists(
  159. collection_name=collection_name, dimension=len(items[0]["vector"])
  160. )
  161. for batch in self._create_batches(items):
  162. actions = [
  163. {
  164. "_op_type": "index",
  165. "_index": self._get_index_name(collection_name),
  166. "_id": item["id"],
  167. "_source": {
  168. "vector": item["vector"],
  169. "text": item["text"],
  170. "metadata": item["metadata"],
  171. },
  172. }
  173. for item in batch
  174. ]
  175. bulk(self.client, actions)
  176. def upsert(self, collection_name: str, items: list[VectorItem]):
  177. self._create_index_if_not_exists(
  178. collection_name=collection_name, dimension=len(items[0]["vector"])
  179. )
  180. for batch in self._create_batches(items):
  181. actions = [
  182. {
  183. "_op_type": "update",
  184. "_index": self._get_index_name(collection_name),
  185. "_id": item["id"],
  186. "doc": {
  187. "vector": item["vector"],
  188. "text": item["text"],
  189. "metadata": item["metadata"],
  190. },
  191. "doc_as_upsert": True,
  192. }
  193. for item in batch
  194. ]
  195. bulk(self.client, actions)
  196. def delete(
  197. self,
  198. collection_name: str,
  199. ids: Optional[list[str]] = None,
  200. filter: Optional[dict] = None,
  201. ):
  202. if ids:
  203. actions = [
  204. {
  205. "_op_type": "delete",
  206. "_index": self._get_index_name(collection_name),
  207. "_id": id,
  208. }
  209. for id in ids
  210. ]
  211. bulk(self.client, actions)
  212. elif filter:
  213. query_body = {
  214. "query": {"bool": {"filter": []}},
  215. }
  216. for field, value in filter.items():
  217. query_body["query"]["bool"]["filter"].append(
  218. {"match": {"metadata." + str(field): value}}
  219. )
  220. self.client.delete_by_query(
  221. index=self._get_index_name(collection_name), body=query_body
  222. )
  223. def reset(self):
  224. indices = self.client.indices.get(index=f"{self.index_prefix}_*")
  225. for index in indices:
  226. self.client.indices.delete(index=index)