qdrant_multitenancy.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. import logging
  2. from typing import Optional, Tuple, List, Dict, Any
  3. from urllib.parse import urlparse
  4. import grpc
  5. from open_webui.config import (
  6. QDRANT_API_KEY,
  7. QDRANT_GRPC_PORT,
  8. QDRANT_ON_DISK,
  9. QDRANT_PREFER_GRPC,
  10. QDRANT_URI,
  11. QDRANT_COLLECTION_PREFIX,
  12. QDRANT_TIMEOUT,
  13. QDRANT_HNSW_M,
  14. )
  15. from open_webui.env import SRC_LOG_LEVELS
  16. from open_webui.retrieval.vector.main import (
  17. GetResult,
  18. SearchResult,
  19. VectorDBBase,
  20. VectorItem,
  21. )
  22. from qdrant_client import QdrantClient as Qclient
  23. from qdrant_client.http.exceptions import UnexpectedResponse
  24. from qdrant_client.http.models import PointStruct
  25. from qdrant_client.models import models
  26. NO_LIMIT = 999999999
  27. TENANT_ID_FIELD = "tenant_id"
  28. DEFAULT_DIMENSION = 384
  29. log = logging.getLogger(__name__)
  30. log.setLevel(SRC_LOG_LEVELS["RAG"])
  31. def _tenant_filter(tenant_id: str) -> models.FieldCondition:
  32. return models.FieldCondition(
  33. key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
  34. )
  35. def _metadata_filter(key: str, value: Any) -> models.FieldCondition:
  36. return models.FieldCondition(
  37. key=f"metadata.{key}", match=models.MatchValue(value=value)
  38. )
  39. class QdrantClient(VectorDBBase):
  40. def __init__(self):
  41. self.collection_prefix = QDRANT_COLLECTION_PREFIX
  42. self.QDRANT_URI = QDRANT_URI
  43. self.QDRANT_API_KEY = QDRANT_API_KEY
  44. self.QDRANT_ON_DISK = QDRANT_ON_DISK
  45. self.PREFER_GRPC = QDRANT_PREFER_GRPC
  46. self.GRPC_PORT = QDRANT_GRPC_PORT
  47. self.QDRANT_TIMEOUT = QDRANT_TIMEOUT
  48. self.QDRANT_HNSW_M = QDRANT_HNSW_M
  49. if not self.QDRANT_URI:
  50. raise ValueError(
  51. "QDRANT_URI is not set. Please configure it in the environment variables."
  52. )
  53. # Unified handling for either scheme
  54. parsed = urlparse(self.QDRANT_URI)
  55. host = parsed.hostname or self.QDRANT_URI
  56. http_port = parsed.port or 6333 # default REST port
  57. self.client = (
  58. Qclient(
  59. host=host,
  60. port=http_port,
  61. grpc_port=self.GRPC_PORT,
  62. prefer_grpc=self.PREFER_GRPC,
  63. api_key=self.QDRANT_API_KEY,
  64. timeout=self.QDRANT_TIMEOUT,
  65. )
  66. if self.PREFER_GRPC
  67. else Qclient(
  68. url=self.QDRANT_URI,
  69. api_key=self.QDRANT_API_KEY,
  70. timeout=self.QDRANT_TIMEOUT,
  71. )
  72. )
  73. # Main collection types for multi-tenancy
  74. self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
  75. self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
  76. self.FILE_COLLECTION = f"{self.collection_prefix}_files"
  77. self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web-search"
  78. self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
  79. def _result_to_get_result(self, points) -> GetResult:
  80. ids, documents, metadatas = [], [], []
  81. for point in points:
  82. payload = point.payload
  83. ids.append(point.id)
  84. documents.append(payload["text"])
  85. metadatas.append(payload["metadata"])
  86. return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
  87. def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
  88. """
  89. Maps the traditional collection name to multi-tenant collection and tenant ID.
  90. Returns:
  91. tuple: (collection_name, tenant_id)
  92. """
  93. # Check for user memory collections
  94. tenant_id = collection_name
  95. if collection_name.startswith("user-memory-"):
  96. return self.MEMORY_COLLECTION, tenant_id
  97. # Check for file collections
  98. elif collection_name.startswith("file-"):
  99. return self.FILE_COLLECTION, tenant_id
  100. # Check for web search collections
  101. elif collection_name.startswith("web-search-"):
  102. return self.WEB_SEARCH_COLLECTION, tenant_id
  103. # Handle hash-based collections (YouTube and web URLs)
  104. elif len(collection_name) == 63 and all(
  105. c in "0123456789abcdef" for c in collection_name
  106. ):
  107. return self.HASH_BASED_COLLECTION, tenant_id
  108. else:
  109. return self.KNOWLEDGE_COLLECTION, tenant_id
  110. def _create_multi_tenant_collection(
  111. self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
  112. ):
  113. """
  114. Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields.
  115. """
  116. self.client.create_collection(
  117. collection_name=mt_collection_name,
  118. vectors_config=models.VectorParams(
  119. size=dimension,
  120. distance=models.Distance.COSINE,
  121. on_disk=self.QDRANT_ON_DISK,
  122. ),
  123. # Disable global index building due to multitenancy
  124. # For more details https://qdrant.tech/documentation/guides/multiple-partitions/#calibrate-performance
  125. hnsw_config=models.HnswConfigDiff(
  126. payload_m=self.QDRANT_HNSW_M,
  127. m=0,
  128. ),
  129. )
  130. log.info(
  131. f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
  132. )
  133. self.client.create_payload_index(
  134. collection_name=mt_collection_name,
  135. field_name=TENANT_ID_FIELD,
  136. field_schema=models.KeywordIndexParams(
  137. type=models.KeywordIndexType.KEYWORD,
  138. is_tenant=True,
  139. on_disk=self.QDRANT_ON_DISK,
  140. ),
  141. )
  142. for field in ("metadata.hash", "metadata.file_id"):
  143. self.client.create_payload_index(
  144. collection_name=mt_collection_name,
  145. field_name=field,
  146. field_schema=models.KeywordIndexParams(
  147. type=models.KeywordIndexType.KEYWORD,
  148. on_disk=self.QDRANT_ON_DISK,
  149. ),
  150. )
  151. def _create_points(
  152. self, items: List[VectorItem], tenant_id: str
  153. ) -> List[PointStruct]:
  154. """
  155. Create point structs from vector items with tenant ID.
  156. """
  157. return [
  158. PointStruct(
  159. id=item["id"],
  160. vector=item["vector"],
  161. payload={
  162. "text": item["text"],
  163. "metadata": item["metadata"],
  164. TENANT_ID_FIELD: tenant_id,
  165. },
  166. )
  167. for item in items
  168. ]
  169. def _ensure_collection(
  170. self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
  171. ):
  172. """
  173. Ensure the collection exists and payload indexes are created for tenant_id and metadata fields.
  174. """
  175. if not self.client.collection_exists(collection_name=mt_collection_name):
  176. self._create_multi_tenant_collection(mt_collection_name, dimension)
  177. def has_collection(self, collection_name: str) -> bool:
  178. """
  179. Check if a logical collection exists by checking for any points with the tenant ID.
  180. """
  181. if not self.client:
  182. return False
  183. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  184. if not self.client.collection_exists(collection_name=mt_collection):
  185. return False
  186. tenant_filter = _tenant_filter(tenant_id)
  187. count_result = self.client.count(
  188. collection_name=mt_collection,
  189. count_filter=models.Filter(must=[tenant_filter]),
  190. )
  191. return count_result.count > 0
  192. def delete(
  193. self,
  194. collection_name: str,
  195. ids: Optional[List[str]] = None,
  196. filter: Optional[Dict[str, Any]] = None,
  197. ):
  198. """
  199. Delete vectors by ID or filter from a collection with tenant isolation.
  200. """
  201. if not self.client:
  202. return None
  203. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  204. if not self.client.collection_exists(collection_name=mt_collection):
  205. log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
  206. return None
  207. must_conditions = [_tenant_filter(tenant_id)]
  208. should_conditions = []
  209. if ids:
  210. should_conditions = [_metadata_filter("id", id_value) for id_value in ids]
  211. elif filter:
  212. must_conditions += [_metadata_filter(k, v) for k, v in filter.items()]
  213. return self.client.delete(
  214. collection_name=mt_collection,
  215. points_selector=models.FilterSelector(
  216. filter=models.Filter(must=must_conditions, should=should_conditions)
  217. ),
  218. )
  219. def search(
  220. self, collection_name: str, vectors: List[List[float | int]], limit: int
  221. ) -> Optional[SearchResult]:
  222. """
  223. Search for the nearest neighbor items based on the vectors with tenant isolation.
  224. """
  225. if not self.client or not vectors:
  226. return None
  227. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  228. if not self.client.collection_exists(collection_name=mt_collection):
  229. log.debug(f"Collection {mt_collection} doesn't exist, search returns None")
  230. return None
  231. tenant_filter = _tenant_filter(tenant_id)
  232. query_response = self.client.query_points(
  233. collection_name=mt_collection,
  234. query=vectors[0],
  235. limit=limit,
  236. query_filter=models.Filter(must=[tenant_filter]),
  237. )
  238. get_result = self._result_to_get_result(query_response.points)
  239. return SearchResult(
  240. ids=get_result.ids,
  241. documents=get_result.documents,
  242. metadatas=get_result.metadatas,
  243. distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]],
  244. )
  245. def query(
  246. self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
  247. ):
  248. """
  249. Query points with filters and tenant isolation.
  250. """
  251. if not self.client:
  252. return None
  253. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  254. if not self.client.collection_exists(collection_name=mt_collection):
  255. log.debug(f"Collection {mt_collection} doesn't exist, query returns None")
  256. return None
  257. if limit is None:
  258. limit = NO_LIMIT
  259. tenant_filter = _tenant_filter(tenant_id)
  260. field_conditions = [_metadata_filter(k, v) for k, v in filter.items()]
  261. combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
  262. points = self.client.scroll(
  263. collection_name=mt_collection,
  264. scroll_filter=combined_filter,
  265. limit=limit,
  266. )
  267. return self._result_to_get_result(points[0])
  268. def get(self, collection_name: str) -> Optional[GetResult]:
  269. """
  270. Get all items in a collection with tenant isolation.
  271. """
  272. if not self.client:
  273. return None
  274. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  275. if not self.client.collection_exists(collection_name=mt_collection):
  276. log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
  277. return None
  278. tenant_filter = _tenant_filter(tenant_id)
  279. points = self.client.scroll(
  280. collection_name=mt_collection,
  281. scroll_filter=models.Filter(must=[tenant_filter]),
  282. limit=NO_LIMIT,
  283. )
  284. return self._result_to_get_result(points[0])
  285. def upsert(self, collection_name: str, items: List[VectorItem]):
  286. """
  287. Upsert items with tenant ID.
  288. """
  289. if not self.client or not items:
  290. return None
  291. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  292. dimension = len(items[0]["vector"])
  293. self._ensure_collection(mt_collection, dimension)
  294. points = self._create_points(items, tenant_id)
  295. self.client.upload_points(mt_collection, points)
  296. return None
  297. def insert(self, collection_name: str, items: List[VectorItem]):
  298. """
  299. Insert items with tenant ID.
  300. """
  301. return self.upsert(collection_name, items)
  302. def reset(self):
  303. """
  304. Reset the database by deleting all collections.
  305. """
  306. if not self.client:
  307. return None
  308. for collection in self.client.get_collections().collections:
  309. if collection.name.startswith(self.collection_prefix):
  310. self.client.delete_collection(collection_name=collection.name)
  311. def delete_collection(self, collection_name: str):
  312. """
  313. Delete a collection.
  314. """
  315. if not self.client:
  316. return None
  317. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  318. if not self.client.collection_exists(collection_name=mt_collection):
  319. log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
  320. return None
  321. self.client.delete(
  322. collection_name=mt_collection,
  323. points_selector=models.FilterSelector(
  324. filter=models.Filter(must=[_tenant_filter(tenant_id)])
  325. ),
  326. )