chroma.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. import chromadb
  2. import logging
  3. from chromadb import Settings
  4. from chromadb.utils.batch_utils import create_batches
  5. from typing import Optional
  6. from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
  7. from open_webui.config import (
  8. CHROMA_DATA_PATH,
  9. CHROMA_HTTP_HOST,
  10. CHROMA_HTTP_PORT,
  11. CHROMA_HTTP_HEADERS,
  12. CHROMA_HTTP_SSL,
  13. CHROMA_TENANT,
  14. CHROMA_DATABASE,
  15. CHROMA_CLIENT_AUTH_PROVIDER,
  16. CHROMA_CLIENT_AUTH_CREDENTIALS,
  17. )
  18. from open_webui.env import SRC_LOG_LEVELS
  19. log = logging.getLogger(__name__)
  20. log.setLevel(SRC_LOG_LEVELS["RAG"])
  21. class ChromaClient:
  22. def __init__(self):
  23. settings_dict = {
  24. "allow_reset": True,
  25. "anonymized_telemetry": False,
  26. }
  27. if CHROMA_CLIENT_AUTH_PROVIDER is not None:
  28. settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
  29. if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
  30. settings_dict["chroma_client_auth_credentials"] = (
  31. CHROMA_CLIENT_AUTH_CREDENTIALS
  32. )
  33. if CHROMA_HTTP_HOST != "":
  34. self.client = chromadb.HttpClient(
  35. host=CHROMA_HTTP_HOST,
  36. port=CHROMA_HTTP_PORT,
  37. headers=CHROMA_HTTP_HEADERS,
  38. ssl=CHROMA_HTTP_SSL,
  39. tenant=CHROMA_TENANT,
  40. database=CHROMA_DATABASE,
  41. settings=Settings(**settings_dict),
  42. )
  43. else:
  44. self.client = chromadb.PersistentClient(
  45. path=CHROMA_DATA_PATH,
  46. settings=Settings(**settings_dict),
  47. tenant=CHROMA_TENANT,
  48. database=CHROMA_DATABASE,
  49. )
  50. def has_collection(self, collection_name: str) -> bool:
  51. # Check if the collection exists based on the collection name.
  52. collection_names = self.client.list_collections()
  53. return collection_name in collection_names
  54. def delete_collection(self, collection_name: str):
  55. # Delete the collection based on the collection name.
  56. return self.client.delete_collection(name=collection_name)
  57. def search(
  58. self, collection_name: str, vectors: list[list[float | int]], limit: int
  59. ) -> Optional[SearchResult]:
  60. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
  61. try:
  62. collection = self.client.get_collection(name=collection_name)
  63. if collection:
  64. result = collection.query(
  65. query_embeddings=vectors,
  66. n_results=limit,
  67. )
  68. # chromadb has cosine distance, 2 (worst) -> 0 (best). Re-odering to 0 -> 1
  69. # https://docs.trychroma.com/docs/collections/configure cosine equation
  70. distances: list = result["distances"][0]
  71. distances = [2 - dist for dist in distances]
  72. distances = [[dist / 2 for dist in distances]]
  73. return SearchResult(
  74. **{
  75. "ids": result["ids"],
  76. "distances": distances,
  77. "documents": result["documents"],
  78. "metadatas": result["metadatas"],
  79. }
  80. )
  81. return None
  82. except Exception as e:
  83. return None
  84. def query(
  85. self, collection_name: str, filter: dict, limit: Optional[int] = None
  86. ) -> Optional[GetResult]:
  87. # Query the items from the collection based on the filter.
  88. try:
  89. collection = self.client.get_collection(name=collection_name)
  90. if collection:
  91. result = collection.get(
  92. where=filter,
  93. limit=limit,
  94. )
  95. return GetResult(
  96. **{
  97. "ids": [result["ids"]],
  98. "documents": [result["documents"]],
  99. "metadatas": [result["metadatas"]],
  100. }
  101. )
  102. return None
  103. except:
  104. return None
  105. def get(self, collection_name: str) -> Optional[GetResult]:
  106. # Get all the items in the collection.
  107. collection = self.client.get_collection(name=collection_name)
  108. if collection:
  109. result = collection.get()
  110. return GetResult(
  111. **{
  112. "ids": [result["ids"]],
  113. "documents": [result["documents"]],
  114. "metadatas": [result["metadatas"]],
  115. }
  116. )
  117. return None
  118. def insert(self, collection_name: str, items: list[VectorItem]):
  119. # Insert the items into the collection, if the collection does not exist, it will be created.
  120. collection = self.client.get_or_create_collection(
  121. name=collection_name, metadata={"hnsw:space": "cosine"}
  122. )
  123. ids = [item["id"] for item in items]
  124. documents = [item["text"] for item in items]
  125. embeddings = [item["vector"] for item in items]
  126. metadatas = [item["metadata"] for item in items]
  127. for batch in create_batches(
  128. api=self.client,
  129. documents=documents,
  130. embeddings=embeddings,
  131. ids=ids,
  132. metadatas=metadatas,
  133. ):
  134. collection.add(*batch)
  135. def upsert(self, collection_name: str, items: list[VectorItem]):
  136. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
  137. collection = self.client.get_or_create_collection(
  138. name=collection_name, metadata={"hnsw:space": "cosine"}
  139. )
  140. ids = [item["id"] for item in items]
  141. documents = [item["text"] for item in items]
  142. embeddings = [item["vector"] for item in items]
  143. metadatas = [item["metadata"] for item in items]
  144. collection.upsert(
  145. ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
  146. )
  147. def delete(
  148. self,
  149. collection_name: str,
  150. ids: Optional[list[str]] = None,
  151. filter: Optional[dict] = None,
  152. ):
  153. # Delete the items from the collection based on the ids.
  154. try:
  155. collection = self.client.get_collection(name=collection_name)
  156. if collection:
  157. if ids:
  158. collection.delete(ids=ids)
  159. elif filter:
  160. collection.delete(where=filter)
  161. except Exception as e:
  162. # If collection doesn't exist, that's fine - nothing to delete
  163. log.debug(
  164. f"Attempted to delete from non-existent collection {collection_name}. Ignoring."
  165. )
  166. pass
  167. def reset(self):
  168. # Resets the database. This will delete all collections and item entries.
  169. return self.client.reset()