# qdrant_multitenancy.py (13 KB)
  1. import logging
  2. from typing import Optional, Tuple, List, Dict, Any
  3. from urllib.parse import urlparse
  4. import grpc
  5. from open_webui.config import (
  6. QDRANT_API_KEY,
  7. QDRANT_GRPC_PORT,
  8. QDRANT_ON_DISK,
  9. QDRANT_PREFER_GRPC,
  10. QDRANT_URI,
  11. QDRANT_COLLECTION_PREFIX,
  12. )
  13. from open_webui.env import SRC_LOG_LEVELS
  14. from open_webui.retrieval.vector.main import (
  15. GetResult,
  16. SearchResult,
  17. VectorDBBase,
  18. VectorItem,
  19. )
  20. from qdrant_client import QdrantClient as Qclient
  21. from qdrant_client.http.exceptions import UnexpectedResponse
  22. from qdrant_client.http.models import PointStruct
  23. from qdrant_client.models import models
  24. NO_LIMIT = 999999999
  25. TENANT_ID_FIELD = "tenant_id"
  26. DEFAULT_DIMENSION = 384
  27. log = logging.getLogger(__name__)
  28. log.setLevel(SRC_LOG_LEVELS["RAG"])
  29. def _tenant_filter(tenant_id: str) -> models.FieldCondition:
  30. return models.FieldCondition(
  31. key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
  32. )
  33. def _metadata_filter(key: str, value: Any) -> models.FieldCondition:
  34. return models.FieldCondition(
  35. key=f"metadata.{key}", match=models.MatchValue(value=value)
  36. )
  37. class QdrantClient(VectorDBBase):
  38. def __init__(self):
  39. self.collection_prefix = QDRANT_COLLECTION_PREFIX
  40. self.QDRANT_URI = QDRANT_URI
  41. self.QDRANT_API_KEY = QDRANT_API_KEY
  42. self.QDRANT_ON_DISK = QDRANT_ON_DISK
  43. self.PREFER_GRPC = QDRANT_PREFER_GRPC
  44. self.GRPC_PORT = QDRANT_GRPC_PORT
  45. if not self.QDRANT_URI:
  46. self.client = None
  47. return
  48. # Unified handling for either scheme
  49. parsed = urlparse(self.QDRANT_URI)
  50. host = parsed.hostname or self.QDRANT_URI
  51. http_port = parsed.port or 6333 # default REST port
  52. self.client = (
  53. Qclient(
  54. host=host,
  55. port=http_port,
  56. grpc_port=self.GRPC_PORT,
  57. prefer_grpc=self.PREFER_GRPC,
  58. api_key=self.QDRANT_API_KEY,
  59. )
  60. if self.PREFER_GRPC
  61. else Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
  62. )
  63. # Main collection types for multi-tenancy
  64. self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
  65. self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
  66. self.FILE_COLLECTION = f"{self.collection_prefix}_files"
  67. self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web-search"
  68. self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
  69. def _result_to_get_result(self, points) -> GetResult:
  70. ids, documents, metadatas = [], [], []
  71. for point in points:
  72. payload = point.payload
  73. ids.append(point.id)
  74. documents.append(payload["text"])
  75. metadatas.append(payload["metadata"])
  76. return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
  77. def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
  78. """
  79. Maps the traditional collection name to multi-tenant collection and tenant ID.
  80. Returns:
  81. tuple: (collection_name, tenant_id)
  82. """
  83. # Check for user memory collections
  84. tenant_id = collection_name
  85. if collection_name.startswith("user-memory-"):
  86. return self.MEMORY_COLLECTION, tenant_id
  87. # Check for file collections
  88. elif collection_name.startswith("file-"):
  89. return self.FILE_COLLECTION, tenant_id
  90. # Check for web search collections
  91. elif collection_name.startswith("web-search-"):
  92. return self.WEB_SEARCH_COLLECTION, tenant_id
  93. # Handle hash-based collections (YouTube and web URLs)
  94. elif len(collection_name) == 63 and all(
  95. c in "0123456789abcdef" for c in collection_name
  96. ):
  97. return self.HASH_BASED_COLLECTION, tenant_id
  98. else:
  99. return self.KNOWLEDGE_COLLECTION, tenant_id
  100. def _create_multi_tenant_collection(
  101. self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
  102. ):
  103. """
  104. Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields.
  105. """
  106. self.client.create_collection(
  107. collection_name=mt_collection_name,
  108. vectors_config=models.VectorParams(
  109. size=dimension,
  110. distance=models.Distance.COSINE,
  111. on_disk=self.QDRANT_ON_DISK,
  112. ),
  113. )
  114. log.info(
  115. f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
  116. )
  117. self.client.create_payload_index(
  118. collection_name=mt_collection_name,
  119. field_name=TENANT_ID_FIELD,
  120. field_schema=models.KeywordIndexParams(
  121. type=models.KeywordIndexType.KEYWORD,
  122. is_tenant=True,
  123. on_disk=self.QDRANT_ON_DISK,
  124. ),
  125. )
  126. for field in ("metadata.hash", "metadata.file_id"):
  127. self.client.create_payload_index(
  128. collection_name=mt_collection_name,
  129. field_name=field,
  130. field_schema=models.KeywordIndexParams(
  131. type=models.KeywordIndexType.KEYWORD,
  132. on_disk=self.QDRANT_ON_DISK,
  133. ),
  134. )
  135. def _create_points(
  136. self, items: List[VectorItem], tenant_id: str
  137. ) -> List[PointStruct]:
  138. """
  139. Create point structs from vector items with tenant ID.
  140. """
  141. return [
  142. PointStruct(
  143. id=item["id"],
  144. vector=item["vector"],
  145. payload={
  146. "text": item["text"],
  147. "metadata": item["metadata"],
  148. TENANT_ID_FIELD: tenant_id,
  149. },
  150. )
  151. for item in items
  152. ]
  153. def _ensure_collection(
  154. self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
  155. ):
  156. """
  157. Ensure the collection exists and payload indexes are created for tenant_id and metadata fields.
  158. """
  159. if not self.client.collection_exists(collection_name=mt_collection_name):
  160. self._create_multi_tenant_collection(mt_collection_name, dimension)
  161. def has_collection(self, collection_name: str) -> bool:
  162. """
  163. Check if a logical collection exists by checking for any points with the tenant ID.
  164. """
  165. if not self.client:
  166. return False
  167. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  168. if not self.client.collection_exists(collection_name=mt_collection):
  169. return False
  170. tenant_filter = _tenant_filter(tenant_id)
  171. count_result = self.client.count(
  172. collection_name=mt_collection,
  173. count_filter=models.Filter(must=[tenant_filter]),
  174. )
  175. return count_result.count > 0
  176. def delete(
  177. self,
  178. collection_name: str,
  179. ids: Optional[List[str]] = None,
  180. filter: Optional[Dict[str, Any]] = None,
  181. ):
  182. """
  183. Delete vectors by ID or filter from a collection with tenant isolation.
  184. """
  185. if not self.client:
  186. return None
  187. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  188. if not self.client.collection_exists(collection_name=mt_collection):
  189. log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
  190. return None
  191. must_conditions = [_tenant_filter(tenant_id)]
  192. should_conditions = []
  193. if ids:
  194. should_conditions = [_metadata_filter("id", id_value) for id_value in ids]
  195. elif filter:
  196. must_conditions += [_metadata_filter(k, v) for k, v in filter.items()]
  197. try:
  198. return self.client.delete(
  199. collection_name=mt_collection,
  200. points_selector=models.FilterSelector(
  201. filter=models.Filter(must=must_conditions, should=should_conditions)
  202. ),
  203. )
  204. except Exception as e:
  205. log.warning(f"Error deleting from collection {mt_collection}: {e}")
  206. return None
  207. def search(
  208. self, collection_name: str, vectors: List[List[float | int]], limit: int
  209. ) -> Optional[SearchResult]:
  210. """
  211. Search for the nearest neighbor items based on the vectors with tenant isolation.
  212. """
  213. if not self.client or not vectors:
  214. return None
  215. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  216. if not self.client.collection_exists(collection_name=mt_collection):
  217. log.debug(f"Collection {mt_collection} doesn't exist, search returns None")
  218. return None
  219. tenant_filter = _tenant_filter(tenant_id)
  220. query_response = self.client.query_points(
  221. collection_name=mt_collection,
  222. query=vectors[0],
  223. limit=limit,
  224. query_filter=models.Filter(must=[tenant_filter]),
  225. )
  226. get_result = self._result_to_get_result(query_response.points)
  227. return SearchResult(
  228. ids=get_result.ids,
  229. documents=get_result.documents,
  230. metadatas=get_result.metadatas,
  231. distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]],
  232. )
  233. def query(
  234. self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
  235. ):
  236. """
  237. Query points with filters and tenant isolation.
  238. """
  239. if not self.client:
  240. return None
  241. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  242. if not self.client.collection_exists(collection_name=mt_collection):
  243. log.debug(f"Collection {mt_collection} doesn't exist, query returns None")
  244. return None
  245. if limit is None:
  246. limit = NO_LIMIT
  247. tenant_filter = _tenant_filter(tenant_id)
  248. field_conditions = [_metadata_filter(k, v) for k, v in filter.items()]
  249. combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
  250. try:
  251. points = self.client.query_points(
  252. collection_name=mt_collection,
  253. query_filter=combined_filter,
  254. limit=limit,
  255. )
  256. return self._result_to_get_result(points.points)
  257. except Exception as e:
  258. log.exception(f"Error querying collection '{collection_name}': {e}")
  259. return None
  260. def get(self, collection_name: str) -> Optional[GetResult]:
  261. """
  262. Get all items in a collection with tenant isolation.
  263. """
  264. if not self.client:
  265. return None
  266. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  267. if not self.client.collection_exists(collection_name=mt_collection):
  268. log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
  269. return None
  270. tenant_filter = _tenant_filter(tenant_id)
  271. try:
  272. points = self.client.query_points(
  273. collection_name=mt_collection,
  274. query_filter=models.Filter(must=[tenant_filter]),
  275. limit=NO_LIMIT,
  276. )
  277. return self._result_to_get_result(points.points)
  278. except Exception as e:
  279. log.exception(f"Error getting collection '{collection_name}': {e}")
  280. return None
  281. def upsert(self, collection_name: str, items: List[VectorItem]):
  282. """
  283. Upsert items with tenant ID.
  284. """
  285. if not self.client or not items:
  286. return None
  287. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  288. dimension = len(items[0]["vector"])
  289. self._ensure_collection(mt_collection, dimension)
  290. points = self._create_points(items, tenant_id)
  291. self.client.upload_points(mt_collection, points)
  292. return None
  293. def insert(self, collection_name: str, items: List[VectorItem]):
  294. """
  295. Insert items with tenant ID.
  296. """
  297. return self.upsert(collection_name, items)
  298. def reset(self):
  299. """
  300. Reset the database by deleting all collections.
  301. """
  302. if not self.client:
  303. return None
  304. for collection in self.client.get_collections().collections:
  305. if collection.name.startswith(self.collection_prefix):
  306. self.client.delete_collection(collection_name=collection.name)
  307. def delete_collection(self, collection_name: str):
  308. """
  309. Delete a collection.
  310. """
  311. if not self.client:
  312. return None
  313. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  314. if not self.client.collection_exists(collection_name=mt_collection):
  315. log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
  316. return None
  317. self.client.delete(
  318. collection_name=mt_collection,
  319. points_selector=models.FilterSelector(
  320. filter=models.Filter(must=[_tenant_filter(tenant_id)])
  321. ),
  322. )