chroma.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. import chromadb
  2. import logging
  3. from chromadb import Settings
  4. from chromadb.utils.batch_utils import create_batches
  5. from typing import Optional
  6. from open_webui.retrieval.vector.main import (
  7. VectorDBBase,
  8. VectorItem,
  9. SearchResult,
  10. GetResult,
  11. )
  12. from open_webui.retrieval.vector.utils import stringify_metadata
  13. from open_webui.config import (
  14. CHROMA_DATA_PATH,
  15. CHROMA_HTTP_HOST,
  16. CHROMA_HTTP_PORT,
  17. CHROMA_HTTP_HEADERS,
  18. CHROMA_HTTP_SSL,
  19. CHROMA_TENANT,
  20. CHROMA_DATABASE,
  21. CHROMA_CLIENT_AUTH_PROVIDER,
  22. CHROMA_CLIENT_AUTH_CREDENTIALS,
  23. )
  24. from open_webui.env import SRC_LOG_LEVELS
  25. log = logging.getLogger(__name__)
  26. log.setLevel(SRC_LOG_LEVELS["RAG"])
  27. class ChromaClient(VectorDBBase):
  28. def __init__(self):
  29. settings_dict = {
  30. "allow_reset": True,
  31. "anonymized_telemetry": False,
  32. }
  33. if CHROMA_CLIENT_AUTH_PROVIDER is not None:
  34. settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
  35. if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
  36. settings_dict["chroma_client_auth_credentials"] = (
  37. CHROMA_CLIENT_AUTH_CREDENTIALS
  38. )
  39. if CHROMA_HTTP_HOST != "":
  40. self.client = chromadb.HttpClient(
  41. host=CHROMA_HTTP_HOST,
  42. port=CHROMA_HTTP_PORT,
  43. headers=CHROMA_HTTP_HEADERS,
  44. ssl=CHROMA_HTTP_SSL,
  45. tenant=CHROMA_TENANT,
  46. database=CHROMA_DATABASE,
  47. settings=Settings(**settings_dict),
  48. )
  49. else:
  50. self.client = chromadb.PersistentClient(
  51. path=CHROMA_DATA_PATH,
  52. settings=Settings(**settings_dict),
  53. tenant=CHROMA_TENANT,
  54. database=CHROMA_DATABASE,
  55. )
  56. def has_collection(self, collection_name: str) -> bool:
  57. # Check if the collection exists based on the collection name.
  58. collection_names = self.client.list_collections()
  59. return collection_name in collection_names
  60. def delete_collection(self, collection_name: str):
  61. # Delete the collection based on the collection name.
  62. return self.client.delete_collection(name=collection_name)
  63. def search(
  64. self, collection_name: str, vectors: list[list[float | int]], limit: int
  65. ) -> Optional[SearchResult]:
  66. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
  67. try:
  68. collection = self.client.get_collection(name=collection_name)
  69. if collection:
  70. result = collection.query(
  71. query_embeddings=vectors,
  72. n_results=limit,
  73. )
  74. # chromadb has cosine distance, 2 (worst) -> 0 (best). Re-odering to 0 -> 1
  75. # https://docs.trychroma.com/docs/collections/configure cosine equation
  76. distances: list = result["distances"][0]
  77. distances = [2 - dist for dist in distances]
  78. distances = [[dist / 2 for dist in distances]]
  79. return SearchResult(
  80. **{
  81. "ids": result["ids"],
  82. "distances": distances,
  83. "documents": result["documents"],
  84. "metadatas": result["metadatas"],
  85. }
  86. )
  87. return None
  88. except Exception as e:
  89. return None
  90. def query(
  91. self, collection_name: str, filter: dict, limit: Optional[int] = None
  92. ) -> Optional[GetResult]:
  93. # Query the items from the collection based on the filter.
  94. try:
  95. collection = self.client.get_collection(name=collection_name)
  96. if collection:
  97. result = collection.get(
  98. where=filter,
  99. limit=limit,
  100. )
  101. return GetResult(
  102. **{
  103. "ids": [result["ids"]],
  104. "documents": [result["documents"]],
  105. "metadatas": [result["metadatas"]],
  106. }
  107. )
  108. return None
  109. except:
  110. return None
  111. def get(self, collection_name: str) -> Optional[GetResult]:
  112. # Get all the items in the collection.
  113. collection = self.client.get_collection(name=collection_name)
  114. if collection:
  115. result = collection.get()
  116. return GetResult(
  117. **{
  118. "ids": [result["ids"]],
  119. "documents": [result["documents"]],
  120. "metadatas": [result["metadatas"]],
  121. }
  122. )
  123. return None
  124. def insert(self, collection_name: str, items: list[VectorItem]):
  125. # Insert the items into the collection, if the collection does not exist, it will be created.
  126. collection = self.client.get_or_create_collection(
  127. name=collection_name, metadata={"hnsw:space": "cosine"}
  128. )
  129. ids = [item["id"] for item in items]
  130. documents = [item["text"] for item in items]
  131. embeddings = [item["vector"] for item in items]
  132. metadatas = [stringify_metadata(item["metadata"]) for item in items]
  133. for batch in create_batches(
  134. api=self.client,
  135. documents=documents,
  136. embeddings=embeddings,
  137. ids=ids,
  138. metadatas=metadatas,
  139. ):
  140. collection.add(*batch)
  141. def upsert(self, collection_name: str, items: list[VectorItem]):
  142. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
  143. collection = self.client.get_or_create_collection(
  144. name=collection_name, metadata={"hnsw:space": "cosine"}
  145. )
  146. ids = [item["id"] for item in items]
  147. documents = [item["text"] for item in items]
  148. embeddings = [item["vector"] for item in items]
  149. metadatas = [stringify_metadata(item["metadata"]) for item in items]
  150. collection.upsert(
  151. ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
  152. )
  153. def delete(
  154. self,
  155. collection_name: str,
  156. ids: Optional[list[str]] = None,
  157. filter: Optional[dict] = None,
  158. ):
  159. # Delete the items from the collection based on the ids.
  160. try:
  161. collection = self.client.get_collection(name=collection_name)
  162. if collection:
  163. if ids:
  164. collection.delete(ids=ids)
  165. elif filter:
  166. collection.delete(where=filter)
  167. except Exception as e:
  168. # If collection doesn't exist, that's fine - nothing to delete
  169. log.debug(
  170. f"Attempted to delete from non-existent collection {collection_name}. Ignoring."
  171. )
  172. pass
  173. def reset(self):
  174. # Resets the database. This will delete all collections and item entries.
  175. return self.client.reset()