qdrant_multitenancy.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. import logging
  2. from typing import Optional, Tuple, List, Dict, Any
  3. from urllib.parse import urlparse
  4. import grpc
  5. from open_webui.config import (
  6. QDRANT_API_KEY,
  7. QDRANT_GRPC_PORT,
  8. QDRANT_ON_DISK,
  9. QDRANT_PREFER_GRPC,
  10. QDRANT_URI,
  11. QDRANT_COLLECTION_PREFIX,
  12. )
  13. from open_webui.env import SRC_LOG_LEVELS
  14. from open_webui.retrieval.vector.main import (
  15. GetResult,
  16. SearchResult,
  17. VectorDBBase,
  18. VectorItem,
  19. )
  20. from qdrant_client import QdrantClient as Qclient
  21. from qdrant_client.http.exceptions import UnexpectedResponse
  22. from qdrant_client.http.models import PointStruct
  23. from qdrant_client.models import models
  24. NO_LIMIT = 999999999
  25. TENANT_ID_FIELD = "tenant_id"
  26. DEFAULT_DIMENSION = 384
  27. log = logging.getLogger(__name__)
  28. log.setLevel(SRC_LOG_LEVELS["RAG"])
  29. def _tenant_filter(tenant_id: str) -> models.FieldCondition:
  30. return models.FieldCondition(
  31. key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
  32. )
  33. def _metadata_filter(key: str, value: Any) -> models.FieldCondition:
  34. return models.FieldCondition(
  35. key=f"metadata.{key}", match=models.MatchValue(value=value)
  36. )
  37. class QdrantClient(VectorDBBase):
  38. def __init__(self):
  39. self.collection_prefix = QDRANT_COLLECTION_PREFIX
  40. self.QDRANT_URI = QDRANT_URI
  41. self.QDRANT_API_KEY = QDRANT_API_KEY
  42. self.QDRANT_ON_DISK = QDRANT_ON_DISK
  43. self.PREFER_GRPC = QDRANT_PREFER_GRPC
  44. self.GRPC_PORT = QDRANT_GRPC_PORT
  45. if not self.QDRANT_URI:
  46. raise ValueError(
  47. "QDRANT_URI is not set. Please configure it in the environment variables."
  48. )
  49. # Unified handling for either scheme
  50. parsed = urlparse(self.QDRANT_URI)
  51. host = parsed.hostname or self.QDRANT_URI
  52. http_port = parsed.port or 6333 # default REST port
  53. self.client = (
  54. Qclient(
  55. host=host,
  56. port=http_port,
  57. grpc_port=self.GRPC_PORT,
  58. prefer_grpc=self.PREFER_GRPC,
  59. api_key=self.QDRANT_API_KEY,
  60. )
  61. if self.PREFER_GRPC
  62. else Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
  63. )
  64. # Main collection types for multi-tenancy
  65. self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
  66. self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
  67. self.FILE_COLLECTION = f"{self.collection_prefix}_files"
  68. self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web-search"
  69. self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
  70. def _result_to_get_result(self, points) -> GetResult:
  71. ids, documents, metadatas = [], [], []
  72. for point in points:
  73. payload = point.payload
  74. ids.append(point.id)
  75. documents.append(payload["text"])
  76. metadatas.append(payload["metadata"])
  77. return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
  78. def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
  79. """
  80. Maps the traditional collection name to multi-tenant collection and tenant ID.
  81. Returns:
  82. tuple: (collection_name, tenant_id)
  83. """
  84. # Check for user memory collections
  85. tenant_id = collection_name
  86. if collection_name.startswith("user-memory-"):
  87. return self.MEMORY_COLLECTION, tenant_id
  88. # Check for file collections
  89. elif collection_name.startswith("file-"):
  90. return self.FILE_COLLECTION, tenant_id
  91. # Check for web search collections
  92. elif collection_name.startswith("web-search-"):
  93. return self.WEB_SEARCH_COLLECTION, tenant_id
  94. # Handle hash-based collections (YouTube and web URLs)
  95. elif len(collection_name) == 63 and all(
  96. c in "0123456789abcdef" for c in collection_name
  97. ):
  98. return self.HASH_BASED_COLLECTION, tenant_id
  99. else:
  100. return self.KNOWLEDGE_COLLECTION, tenant_id
  101. def _create_multi_tenant_collection(
  102. self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
  103. ):
  104. """
  105. Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields.
  106. """
  107. self.client.create_collection(
  108. collection_name=mt_collection_name,
  109. vectors_config=models.VectorParams(
  110. size=dimension,
  111. distance=models.Distance.COSINE,
  112. on_disk=self.QDRANT_ON_DISK,
  113. ),
  114. )
  115. log.info(
  116. f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
  117. )
  118. self.client.create_payload_index(
  119. collection_name=mt_collection_name,
  120. field_name=TENANT_ID_FIELD,
  121. field_schema=models.KeywordIndexParams(
  122. type=models.KeywordIndexType.KEYWORD,
  123. is_tenant=True,
  124. on_disk=self.QDRANT_ON_DISK,
  125. ),
  126. )
  127. for field in ("metadata.hash", "metadata.file_id"):
  128. self.client.create_payload_index(
  129. collection_name=mt_collection_name,
  130. field_name=field,
  131. field_schema=models.KeywordIndexParams(
  132. type=models.KeywordIndexType.KEYWORD,
  133. on_disk=self.QDRANT_ON_DISK,
  134. ),
  135. )
  136. def _create_points(
  137. self, items: List[VectorItem], tenant_id: str
  138. ) -> List[PointStruct]:
  139. """
  140. Create point structs from vector items with tenant ID.
  141. """
  142. return [
  143. PointStruct(
  144. id=item["id"],
  145. vector=item["vector"],
  146. payload={
  147. "text": item["text"],
  148. "metadata": item["metadata"],
  149. TENANT_ID_FIELD: tenant_id,
  150. },
  151. )
  152. for item in items
  153. ]
  154. def _ensure_collection(
  155. self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
  156. ):
  157. """
  158. Ensure the collection exists and payload indexes are created for tenant_id and metadata fields.
  159. """
  160. if not self.client.collection_exists(collection_name=mt_collection_name):
  161. self._create_multi_tenant_collection(mt_collection_name, dimension)
  162. def has_collection(self, collection_name: str) -> bool:
  163. """
  164. Check if a logical collection exists by checking for any points with the tenant ID.
  165. """
  166. if not self.client:
  167. return False
  168. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  169. if not self.client.collection_exists(collection_name=mt_collection):
  170. return False
  171. tenant_filter = _tenant_filter(tenant_id)
  172. count_result = self.client.count(
  173. collection_name=mt_collection,
  174. count_filter=models.Filter(must=[tenant_filter]),
  175. )
  176. return count_result.count > 0
  177. def delete(
  178. self,
  179. collection_name: str,
  180. ids: Optional[List[str]] = None,
  181. filter: Optional[Dict[str, Any]] = None,
  182. ):
  183. """
  184. Delete vectors by ID or filter from a collection with tenant isolation.
  185. """
  186. if not self.client:
  187. return None
  188. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  189. if not self.client.collection_exists(collection_name=mt_collection):
  190. log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
  191. return None
  192. must_conditions = [_tenant_filter(tenant_id)]
  193. should_conditions = []
  194. if ids:
  195. should_conditions = [_metadata_filter("id", id_value) for id_value in ids]
  196. elif filter:
  197. must_conditions += [_metadata_filter(k, v) for k, v in filter.items()]
  198. return self.client.delete(
  199. collection_name=mt_collection,
  200. points_selector=models.FilterSelector(
  201. filter=models.Filter(must=must_conditions, should=should_conditions)
  202. ),
  203. )
  204. def search(
  205. self, collection_name: str, vectors: List[List[float | int]], limit: int
  206. ) -> Optional[SearchResult]:
  207. """
  208. Search for the nearest neighbor items based on the vectors with tenant isolation.
  209. """
  210. if not self.client or not vectors:
  211. return None
  212. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  213. if not self.client.collection_exists(collection_name=mt_collection):
  214. log.debug(f"Collection {mt_collection} doesn't exist, search returns None")
  215. return None
  216. tenant_filter = _tenant_filter(tenant_id)
  217. query_response = self.client.query_points(
  218. collection_name=mt_collection,
  219. query=vectors[0],
  220. limit=limit,
  221. query_filter=models.Filter(must=[tenant_filter]),
  222. )
  223. get_result = self._result_to_get_result(query_response.points)
  224. return SearchResult(
  225. ids=get_result.ids,
  226. documents=get_result.documents,
  227. metadatas=get_result.metadatas,
  228. distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]],
  229. )
  230. def query(
  231. self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
  232. ):
  233. """
  234. Query points with filters and tenant isolation.
  235. """
  236. if not self.client:
  237. return None
  238. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  239. if not self.client.collection_exists(collection_name=mt_collection):
  240. log.debug(f"Collection {mt_collection} doesn't exist, query returns None")
  241. return None
  242. if limit is None:
  243. limit = NO_LIMIT
  244. tenant_filter = _tenant_filter(tenant_id)
  245. field_conditions = [_metadata_filter(k, v) for k, v in filter.items()]
  246. combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
  247. points = self.client.query_points(
  248. collection_name=mt_collection,
  249. query_filter=combined_filter,
  250. limit=limit,
  251. )
  252. return self._result_to_get_result(points.points)
  253. def get(self, collection_name: str) -> Optional[GetResult]:
  254. """
  255. Get all items in a collection with tenant isolation.
  256. """
  257. if not self.client:
  258. return None
  259. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  260. if not self.client.collection_exists(collection_name=mt_collection):
  261. log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
  262. return None
  263. tenant_filter = _tenant_filter(tenant_id)
  264. points = self.client.query_points(
  265. collection_name=mt_collection,
  266. query_filter=models.Filter(must=[tenant_filter]),
  267. limit=NO_LIMIT,
  268. )
  269. return self._result_to_get_result(points.points)
  270. def upsert(self, collection_name: str, items: List[VectorItem]):
  271. """
  272. Upsert items with tenant ID.
  273. """
  274. if not self.client or not items:
  275. return None
  276. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  277. dimension = len(items[0]["vector"])
  278. self._ensure_collection(mt_collection, dimension)
  279. points = self._create_points(items, tenant_id)
  280. self.client.upload_points(mt_collection, points)
  281. return None
  282. def insert(self, collection_name: str, items: List[VectorItem]):
  283. """
  284. Insert items with tenant ID.
  285. """
  286. return self.upsert(collection_name, items)
  287. def reset(self):
  288. """
  289. Reset the database by deleting all collections.
  290. """
  291. if not self.client:
  292. return None
  293. for collection in self.client.get_collections().collections:
  294. if collection.name.startswith(self.collection_prefix):
  295. self.client.delete_collection(collection_name=collection.name)
  296. def delete_collection(self, collection_name: str):
  297. """
  298. Delete a collection.
  299. """
  300. if not self.client:
  301. return None
  302. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  303. if not self.client.collection_exists(collection_name=mt_collection):
  304. log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
  305. return None
  306. self.client.delete(
  307. collection_name=mt_collection,
  308. points_selector=models.FilterSelector(
  309. filter=models.Filter(must=[_tenant_filter(tenant_id)])
  310. ),
  311. )