chroma.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. import chromadb
  2. import logging
  3. from chromadb import Settings
  4. from chromadb.utils.batch_utils import create_batches
  5. from typing import Optional
  6. from open_webui.retrieval.vector.main import (
  7. VectorDBBase,
  8. VectorItem,
  9. SearchResult,
  10. GetResult,
  11. )
  12. from open_webui.config import (
  13. CHROMA_DATA_PATH,
  14. CHROMA_HTTP_HOST,
  15. CHROMA_HTTP_PORT,
  16. CHROMA_HTTP_HEADERS,
  17. CHROMA_HTTP_SSL,
  18. CHROMA_TENANT,
  19. CHROMA_DATABASE,
  20. CHROMA_CLIENT_AUTH_PROVIDER,
  21. CHROMA_CLIENT_AUTH_CREDENTIALS,
  22. )
  23. from open_webui.env import SRC_LOG_LEVELS
  24. log = logging.getLogger(__name__)
  25. log.setLevel(SRC_LOG_LEVELS["RAG"])
  26. class ChromaClient(VectorDBBase):
  27. def __init__(self):
  28. settings_dict = {
  29. "allow_reset": True,
  30. "anonymized_telemetry": False,
  31. }
  32. if CHROMA_CLIENT_AUTH_PROVIDER is not None:
  33. settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
  34. if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
  35. settings_dict["chroma_client_auth_credentials"] = (
  36. CHROMA_CLIENT_AUTH_CREDENTIALS
  37. )
  38. if CHROMA_HTTP_HOST != "":
  39. self.client = chromadb.HttpClient(
  40. host=CHROMA_HTTP_HOST,
  41. port=CHROMA_HTTP_PORT,
  42. headers=CHROMA_HTTP_HEADERS,
  43. ssl=CHROMA_HTTP_SSL,
  44. tenant=CHROMA_TENANT,
  45. database=CHROMA_DATABASE,
  46. settings=Settings(**settings_dict),
  47. )
  48. else:
  49. self.client = chromadb.PersistentClient(
  50. path=CHROMA_DATA_PATH,
  51. settings=Settings(**settings_dict),
  52. tenant=CHROMA_TENANT,
  53. database=CHROMA_DATABASE,
  54. )
  55. def has_collection(self, collection_name: str) -> bool:
  56. # Check if the collection exists based on the collection name.
  57. collection_names = self.client.list_collections()
  58. return collection_name in collection_names
  59. def delete_collection(self, collection_name: str):
  60. # Delete the collection based on the collection name.
  61. return self.client.delete_collection(name=collection_name)
  62. def search(
  63. self, collection_name: str, vectors: list[list[float | int]], limit: int
  64. ) -> Optional[SearchResult]:
  65. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
  66. try:
  67. collection = self.client.get_collection(name=collection_name)
  68. if collection:
  69. result = collection.query(
  70. query_embeddings=vectors,
  71. n_results=limit,
  72. )
  73. # chromadb has cosine distance, 2 (worst) -> 0 (best). Re-odering to 0 -> 1
  74. # https://docs.trychroma.com/docs/collections/configure cosine equation
  75. distances: list = result["distances"][0]
  76. distances = [2 - dist for dist in distances]
  77. distances = [[dist / 2 for dist in distances]]
  78. return SearchResult(
  79. **{
  80. "ids": result["ids"],
  81. "distances": distances,
  82. "documents": result["documents"],
  83. "metadatas": result["metadatas"],
  84. }
  85. )
  86. return None
  87. except Exception as e:
  88. return None
  89. def query(
  90. self, collection_name: str, filter: dict, limit: Optional[int] = None
  91. ) -> Optional[GetResult]:
  92. # Query the items from the collection based on the filter.
  93. try:
  94. collection = self.client.get_collection(name=collection_name)
  95. if collection:
  96. result = collection.get(
  97. where=filter,
  98. limit=limit,
  99. )
  100. return GetResult(
  101. **{
  102. "ids": [result["ids"]],
  103. "documents": [result["documents"]],
  104. "metadatas": [result["metadatas"]],
  105. }
  106. )
  107. return None
  108. except:
  109. return None
  110. def get(self, collection_name: str) -> Optional[GetResult]:
  111. # Get all the items in the collection.
  112. collection = self.client.get_collection(name=collection_name)
  113. if collection:
  114. result = collection.get()
  115. return GetResult(
  116. **{
  117. "ids": [result["ids"]],
  118. "documents": [result["documents"]],
  119. "metadatas": [result["metadatas"]],
  120. }
  121. )
  122. return None
  123. def insert(self, collection_name: str, items: list[VectorItem]):
  124. # Insert the items into the collection, if the collection does not exist, it will be created.
  125. collection = self.client.get_or_create_collection(
  126. name=collection_name, metadata={"hnsw:space": "cosine"}
  127. )
  128. ids = [item["id"] for item in items]
  129. documents = [item["text"] for item in items]
  130. embeddings = [item["vector"] for item in items]
  131. metadatas = [item["metadata"] for item in items]
  132. for batch in create_batches(
  133. api=self.client,
  134. documents=documents,
  135. embeddings=embeddings,
  136. ids=ids,
  137. metadatas=metadatas,
  138. ):
  139. collection.add(*batch)
  140. def upsert(self, collection_name: str, items: list[VectorItem]):
  141. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
  142. collection = self.client.get_or_create_collection(
  143. name=collection_name, metadata={"hnsw:space": "cosine"}
  144. )
  145. ids = [item["id"] for item in items]
  146. documents = [item["text"] for item in items]
  147. embeddings = [item["vector"] for item in items]
  148. metadatas = [item["metadata"] for item in items]
  149. collection.upsert(
  150. ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
  151. )
  152. def delete(
  153. self,
  154. collection_name: str,
  155. ids: Optional[list[str]] = None,
  156. filter: Optional[dict] = None,
  157. ):
  158. # Delete the items from the collection based on the ids.
  159. try:
  160. collection = self.client.get_collection(name=collection_name)
  161. if collection:
  162. if ids:
  163. collection.delete(ids=ids)
  164. elif filter:
  165. collection.delete(where=filter)
  166. except Exception as e:
  167. # If collection doesn't exist, that's fine - nothing to delete
  168. log.debug(
  169. f"Attempted to delete from non-existent collection {collection_name}. Ignoring."
  170. )
  171. pass
  172. def reset(self):
  173. # Resets the database. This will delete all collections and item entries.
  174. return self.client.reset()