qdrant_multitenancy.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. import logging
  2. from typing import Optional, Tuple
  3. from urllib.parse import urlparse
  4. import grpc
  5. from open_webui.config import (
  6. QDRANT_API_KEY,
  7. QDRANT_GRPC_PORT,
  8. QDRANT_ON_DISK,
  9. QDRANT_PREFER_GRPC,
  10. QDRANT_URI,
  11. )
  12. from open_webui.env import SRC_LOG_LEVELS
  13. from open_webui.retrieval.vector.main import (
  14. GetResult,
  15. SearchResult,
  16. VectorDBBase,
  17. VectorItem,
  18. )
  19. from qdrant_client import QdrantClient as Qclient
  20. from qdrant_client.http.exceptions import UnexpectedResponse
  21. from qdrant_client.http.models import PointStruct
  22. from qdrant_client.models import models
  23. NO_LIMIT = 999999999
  24. TENANT_ID_FIELD = "tenant_id"
  25. log = logging.getLogger(__name__)
  26. log.setLevel(SRC_LOG_LEVELS["RAG"])
  27. class QdrantClient(VectorDBBase):
  28. def __init__(self):
  29. self.collection_prefix = "open-webui"
  30. self.QDRANT_URI = QDRANT_URI
  31. self.QDRANT_API_KEY = QDRANT_API_KEY
  32. self.QDRANT_ON_DISK = QDRANT_ON_DISK
  33. self.PREFER_GRPC = QDRANT_PREFER_GRPC
  34. self.GRPC_PORT = QDRANT_GRPC_PORT
  35. if not self.QDRANT_URI:
  36. self.client = None
  37. return
  38. # Unified handling for either scheme
  39. parsed = urlparse(self.QDRANT_URI)
  40. host = parsed.hostname or self.QDRANT_URI
  41. http_port = parsed.port or 6333 # default REST port
  42. if self.PREFER_GRPC:
  43. self.client = Qclient(
  44. host=host,
  45. port=http_port,
  46. grpc_port=self.GRPC_PORT,
  47. prefer_grpc=self.PREFER_GRPC,
  48. api_key=self.QDRANT_API_KEY,
  49. )
  50. else:
  51. self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
  52. # Main collection types for multi-tenancy
  53. self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
  54. self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
  55. self.FILE_COLLECTION = f"{self.collection_prefix}_files"
  56. self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web-search"
  57. self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
  58. def _result_to_get_result(self, points) -> GetResult:
  59. ids = []
  60. documents = []
  61. metadatas = []
  62. for point in points:
  63. payload = point.payload
  64. ids.append(point.id)
  65. documents.append(payload["text"])
  66. metadatas.append(payload["metadata"])
  67. return GetResult(
  68. **{
  69. "ids": [ids],
  70. "documents": [documents],
  71. "metadatas": [metadatas],
  72. }
  73. )
  74. def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
  75. """
  76. Maps the traditional collection name to multi-tenant collection and tenant ID.
  77. Returns:
  78. tuple: (collection_name, tenant_id)
  79. """
  80. # Check for user memory collections
  81. tenant_id = collection_name
  82. if collection_name.startswith("user-memory-"):
  83. return self.MEMORY_COLLECTION, tenant_id
  84. # Check for file collections
  85. elif collection_name.startswith("file-"):
  86. return self.FILE_COLLECTION, tenant_id
  87. # Check for web search collections
  88. elif collection_name.startswith("web-search-"):
  89. return self.WEB_SEARCH_COLLECTION, tenant_id
  90. # Handle hash-based collections (YouTube and web URLs)
  91. elif len(collection_name) == 63 and all(
  92. c in "0123456789abcdef" for c in collection_name
  93. ):
  94. return self.HASH_BASED_COLLECTION, tenant_id
  95. else:
  96. return self.KNOWLEDGE_COLLECTION, tenant_id
  97. def _create_multi_tenant_collection(
  98. self,
  99. mt_collection_name: str,
  100. dimension: int = 384,
  101. ):
  102. """
  103. Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields.
  104. """
  105. self.client.create_collection(
  106. collection_name=mt_collection_name,
  107. vectors_config=models.VectorParams(
  108. size=dimension,
  109. distance=models.Distance.COSINE,
  110. on_disk=self.QDRANT_ON_DISK,
  111. ),
  112. )
  113. log.info(
  114. f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
  115. )
  116. self.client.create_payload_index(
  117. collection_name=mt_collection_name,
  118. field_name=TENANT_ID_FIELD,
  119. field_schema=models.KeywordIndexParams(
  120. type=models.KeywordIndexType.KEYWORD,
  121. is_tenant=True,
  122. on_disk=self.QDRANT_ON_DISK,
  123. ),
  124. )
  125. def _create_points(self, items: list[VectorItem], tenant_id: str):
  126. """
  127. Create point structs from vector items with tenant ID.
  128. """
  129. return [
  130. PointStruct(
  131. id=item["id"],
  132. vector=item["vector"],
  133. payload={
  134. "text": item["text"],
  135. "metadata": item["metadata"],
  136. TENANT_ID_FIELD: tenant_id,
  137. },
  138. )
  139. for item in items
  140. ]
  141. def _ensure_collection(
  142. self,
  143. mt_collection_name: str,
  144. dimension: int = 384,
  145. ):
  146. """
  147. Ensure the collection exists and payload indexes are created for tenant_id and metadata fields.
  148. """
  149. if self.client.collection_exists(collection_name=mt_collection_name):
  150. return
  151. self._create_multi_tenant_collection(mt_collection_name, dimension)
  152. def has_collection(self, collection_name: str) -> bool:
  153. """
  154. Check if a logical collection exists by checking for any points with the tenant ID.
  155. """
  156. if not self.client:
  157. return False
  158. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  159. if not self.client.collection_exists(collection_name=mt_collection):
  160. return False
  161. tenant_filter = models.FieldCondition(
  162. key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
  163. )
  164. count_result = self.client.count(
  165. collection_name=mt_collection,
  166. count_filter=models.Filter(must=[tenant_filter]),
  167. )
  168. return count_result.count > 0
  169. def delete(
  170. self,
  171. collection_name: str,
  172. ids: Optional[list[str]] = None,
  173. filter: Optional[dict] = None,
  174. ):
  175. """
  176. Delete vectors by ID or filter from a collection with tenant isolation.
  177. """
  178. if not self.client:
  179. return None
  180. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  181. if not self.client.collection_exists(collection_name=mt_collection):
  182. log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
  183. return None
  184. tenant_filter = models.FieldCondition(
  185. key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
  186. )
  187. must_conditions = [tenant_filter]
  188. should_conditions = []
  189. if ids:
  190. for id_value in ids:
  191. should_conditions.append(
  192. models.FieldCondition(
  193. key="metadata.id",
  194. match=models.MatchValue(value=id_value),
  195. ),
  196. )
  197. elif filter:
  198. for key, value in filter.items():
  199. must_conditions.append(
  200. models.FieldCondition(
  201. key=f"metadata.{key}",
  202. match=models.MatchValue(value=value),
  203. ),
  204. )
  205. try:
  206. update_result = self.client.delete(
  207. collection_name=mt_collection,
  208. points_selector=models.FilterSelector(
  209. filter=models.Filter(must=must_conditions, should=should_conditions)
  210. ),
  211. )
  212. return update_result
  213. except Exception as e:
  214. log.warning(f"Error deleting from collection {mt_collection}: {e}")
  215. return None
  216. def search(
  217. self, collection_name: str, vectors: list[list[float | int]], limit: int
  218. ) -> Optional[SearchResult]:
  219. """
  220. Search for the nearest neighbor items based on the vectors with tenant isolation.
  221. """
  222. if not self.client:
  223. return None
  224. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  225. if not self.client.collection_exists(collection_name=mt_collection):
  226. log.debug(f"Collection {mt_collection} doesn't exist, search returns None")
  227. return None
  228. dimension = len(vectors[0]) if vectors and len(vectors) > 0 else None
  229. try:
  230. tenant_filter = models.FieldCondition(
  231. key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
  232. )
  233. collection_dim = self.client.get_collection(
  234. mt_collection
  235. ).config.params.vectors.size
  236. if collection_dim != dimension:
  237. if collection_dim < dimension:
  238. vectors = [vector[:collection_dim] for vector in vectors]
  239. else:
  240. vectors = [
  241. vector + [0] * (collection_dim - dimension)
  242. for vector in vectors
  243. ]
  244. prefetch_query = models.Prefetch(
  245. filter=models.Filter(must=[tenant_filter]),
  246. limit=NO_LIMIT,
  247. )
  248. query_response = self.client.query_points(
  249. collection_name=mt_collection,
  250. query=vectors[0],
  251. prefetch=prefetch_query,
  252. limit=limit,
  253. )
  254. get_result = self._result_to_get_result(query_response.points)
  255. return SearchResult(
  256. ids=get_result.ids,
  257. documents=get_result.documents,
  258. metadatas=get_result.metadatas,
  259. distances=[
  260. [(point.score + 1.0) / 2.0 for point in query_response.points]
  261. ],
  262. )
  263. except Exception as e:
  264. log.exception(f"Error searching collection '{collection_name}': {e}")
  265. return None
  266. def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
  267. """
  268. Query points with filters and tenant isolation.
  269. """
  270. if not self.client:
  271. return None
  272. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  273. if not self.client.collection_exists(collection_name=mt_collection):
  274. log.debug(f"Collection {mt_collection} doesn't exist, query returns None")
  275. return None
  276. if limit is None:
  277. limit = NO_LIMIT
  278. tenant_filter = models.FieldCondition(
  279. key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
  280. )
  281. field_conditions = []
  282. for key, value in filter.items():
  283. field_conditions.append(
  284. models.FieldCondition(
  285. key=f"metadata.{key}", match=models.MatchValue(value=value)
  286. )
  287. )
  288. combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
  289. try:
  290. points = self.client.query_points(
  291. collection_name=mt_collection,
  292. query_filter=combined_filter,
  293. limit=limit,
  294. )
  295. return self._result_to_get_result(points.points)
  296. except Exception as e:
  297. log.exception(f"Error querying collection '{collection_name}': {e}")
  298. return None
  299. def get(self, collection_name: str) -> Optional[GetResult]:
  300. """
  301. Get all items in a collection with tenant isolation.
  302. """
  303. if not self.client:
  304. return None
  305. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  306. if not self.client.collection_exists(collection_name=mt_collection):
  307. log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
  308. return None
  309. tenant_filter = models.FieldCondition(
  310. key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
  311. )
  312. try:
  313. points = self.client.query_points(
  314. collection_name=mt_collection,
  315. query_filter=models.Filter(must=[tenant_filter]),
  316. limit=NO_LIMIT,
  317. )
  318. return self._result_to_get_result(points.points)
  319. except Exception as e:
  320. log.exception(f"Error getting collection '{collection_name}': {e}")
  321. return None
  322. def upsert(self, collection_name: str, items: list[VectorItem]):
  323. """
  324. Upsert items with tenant ID.
  325. """
  326. if not self.client or not items:
  327. return None
  328. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  329. dimension = len(items[0]["vector"]) if items else None
  330. self._ensure_collection(mt_collection, dimension)
  331. points = self._create_points(items, tenant_id)
  332. self.client.upload_points(mt_collection, points)
  333. return None
  334. def insert(self, collection_name: str, items: list[VectorItem]):
  335. """
  336. Insert items with tenant ID.
  337. """
  338. return self.upsert(collection_name, items)
  339. def reset(self):
  340. """
  341. Reset the database by deleting all collections.
  342. """
  343. if not self.client:
  344. return None
  345. collection_names = self.client.get_collections().collections
  346. for collection_name in collection_names:
  347. if collection_name.name.startswith(self.collection_prefix):
  348. self.client.delete_collection(collection_name=collection_name.name)
  349. def delete_collection(self, collection_name: str):
  350. """
  351. Delete a collection.
  352. """
  353. if not self.client:
  354. return None
  355. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  356. if not self.client.collection_exists(collection_name=mt_collection):
  357. log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
  358. return None
  359. self.client.delete(
  360. collection_name=mt_collection,
  361. points_selector=models.FilterSelector(
  362. filter=models.Filter(
  363. must=[
  364. models.FieldCondition(
  365. key=TENANT_ID_FIELD,
  366. match=models.MatchValue(value=tenant_id),
  367. )
  368. ]
  369. )
  370. ),
  371. )