# qdrant_multitenancy.py (26 KB)


  1. import logging
  2. from typing import Optional, Tuple
  3. from urllib.parse import urlparse
  4. import grpc
  5. from open_webui.config import (
  6. QDRANT_API_KEY,
  7. QDRANT_GRPC_PORT,
  8. QDRANT_ON_DISK,
  9. QDRANT_PREFER_GRPC,
  10. QDRANT_URI,
  11. )
  12. from open_webui.env import SRC_LOG_LEVELS
  13. from open_webui.retrieval.vector.main import (
  14. GetResult,
  15. SearchResult,
  16. VectorDBBase,
  17. VectorItem,
  18. )
  19. from qdrant_client import QdrantClient as Qclient
  20. from qdrant_client.http.exceptions import UnexpectedResponse
  21. from qdrant_client.http.models import PointStruct
  22. from qdrant_client.models import models
  23. NO_LIMIT = 999999999
  24. log = logging.getLogger(__name__)
  25. log.setLevel(SRC_LOG_LEVELS["RAG"])
  26. class QdrantClient(VectorDBBase):
  27. def __init__(self):
  28. self.collection_prefix = "open-webui"
  29. self.QDRANT_URI = QDRANT_URI
  30. self.QDRANT_API_KEY = QDRANT_API_KEY
  31. self.QDRANT_ON_DISK = QDRANT_ON_DISK
  32. self.PREFER_GRPC = QDRANT_PREFER_GRPC
  33. self.GRPC_PORT = QDRANT_GRPC_PORT
  34. if not self.QDRANT_URI:
  35. self.client = None
  36. return
  37. # Unified handling for either scheme
  38. parsed = urlparse(self.QDRANT_URI)
  39. host = parsed.hostname or self.QDRANT_URI
  40. http_port = parsed.port or 6333 # default REST port
  41. if self.PREFER_GRPC:
  42. self.client = Qclient(
  43. host=host,
  44. port=http_port,
  45. grpc_port=self.GRPC_PORT,
  46. prefer_grpc=self.PREFER_GRPC,
  47. api_key=self.QDRANT_API_KEY,
  48. )
  49. else:
  50. self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
  51. # Main collection types for multi-tenancy
  52. self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
  53. self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
  54. self.FILE_COLLECTION = f"{self.collection_prefix}_files"
  55. self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web-search"
  56. self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
  57. def _result_to_get_result(self, points) -> GetResult:
  58. ids = []
  59. documents = []
  60. metadatas = []
  61. for point in points:
  62. payload = point.payload
  63. ids.append(point.id)
  64. documents.append(payload["text"])
  65. metadatas.append(payload["metadata"])
  66. return GetResult(
  67. **{
  68. "ids": [ids],
  69. "documents": [documents],
  70. "metadatas": [metadatas],
  71. }
  72. )
  73. def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
  74. """
  75. Maps the traditional collection name to multi-tenant collection and tenant ID.
  76. Returns:
  77. tuple: (collection_name, tenant_id)
  78. """
  79. # Check for user memory collections
  80. tenant_id = collection_name
  81. if collection_name.startswith("user-memory-"):
  82. return self.MEMORY_COLLECTION, tenant_id
  83. # Check for file collections
  84. elif collection_name.startswith("file-"):
  85. return self.FILE_COLLECTION, tenant_id
  86. # Check for web search collections
  87. elif collection_name.startswith("web-search-"):
  88. return self.WEB_SEARCH_COLLECTION, tenant_id
  89. # Handle hash-based collections (YouTube and web URLs)
  90. elif len(collection_name) == 63 and all(
  91. c in "0123456789abcdef" for c in collection_name
  92. ):
  93. return self.HASH_BASED_COLLECTION, tenant_id
  94. else:
  95. return self.KNOWLEDGE_COLLECTION, tenant_id
  96. def _extract_error_message(self, exception):
  97. """
  98. Extract error message from either HTTP or gRPC exceptions
  99. Returns:
  100. tuple: (status_code, error_message)
  101. """
  102. # Check if it's an HTTP exception
  103. if isinstance(exception, UnexpectedResponse):
  104. try:
  105. error_data = exception.structured()
  106. error_msg = error_data.get("status", {}).get("error", "")
  107. return exception.status_code, error_msg
  108. except Exception as inner_e:
  109. log.error(f"Failed to parse HTTP error: {inner_e}")
  110. return exception.status_code, str(exception)
  111. # Check if it's a gRPC exception
  112. elif isinstance(exception, grpc.RpcError):
  113. # Extract status code from gRPC error
  114. status_code = None
  115. if hasattr(exception, "code") and callable(exception.code):
  116. status_code = exception.code().value[0]
  117. # Extract error message
  118. error_msg = str(exception)
  119. if "details =" in error_msg:
  120. # Parse the details line which contains the actual error message
  121. try:
  122. details_line = [
  123. line.strip()
  124. for line in error_msg.split("\n")
  125. if "details =" in line
  126. ][0]
  127. error_msg = details_line.split("details =")[1].strip(' "')
  128. except (IndexError, AttributeError):
  129. # Fall back to full message if parsing fails
  130. pass
  131. return status_code, error_msg
  132. # For any other type of exception
  133. return None, str(exception)
  134. def _is_collection_not_found_error(self, exception):
  135. """
  136. Check if the exception is due to collection not found, supporting both HTTP and gRPC
  137. """
  138. status_code, error_msg = self._extract_error_message(exception)
  139. # HTTP error (404)
  140. if (
  141. status_code == 404
  142. and "Collection" in error_msg
  143. and "doesn't exist" in error_msg
  144. ):
  145. return True
  146. # gRPC error (NOT_FOUND status)
  147. if (
  148. isinstance(exception, grpc.RpcError)
  149. and exception.code() == grpc.StatusCode.NOT_FOUND
  150. ):
  151. return True
  152. return False
  153. def _is_dimension_mismatch_error(self, exception):
  154. """
  155. Check if the exception is due to dimension mismatch, supporting both HTTP and gRPC
  156. """
  157. status_code, error_msg = self._extract_error_message(exception)
  158. # Common patterns in both HTTP and gRPC
  159. return (
  160. "Vector dimension error" in error_msg
  161. or "dimensions mismatch" in error_msg
  162. or "invalid vector size" in error_msg
  163. )
  164. def _create_multi_tenant_collection_if_not_exists(
  165. self, mt_collection_name: str, dimension: int = 384
  166. ):
  167. """
  168. Creates a collection with multi-tenancy configuration if it doesn't exist.
  169. Default dimension is set to 384 which corresponds to 'sentence-transformers/all-MiniLM-L6-v2'.
  170. When creating collections dynamically (insert/upsert), the actual vector dimensions will be used.
  171. """
  172. try:
  173. # Try to create the collection directly - will fail if it already exists
  174. self.client.create_collection(
  175. collection_name=mt_collection_name,
  176. vectors_config=models.VectorParams(
  177. size=dimension,
  178. distance=models.Distance.COSINE,
  179. on_disk=self.QDRANT_ON_DISK,
  180. ),
  181. hnsw_config=models.HnswConfigDiff(
  182. payload_m=16, # Enable per-tenant indexing
  183. m=0,
  184. on_disk=self.QDRANT_ON_DISK,
  185. ),
  186. )
  187. # Create tenant ID payload index
  188. self.client.create_payload_index(
  189. collection_name=mt_collection_name,
  190. field_name="tenant_id",
  191. field_schema=models.KeywordIndexParams(
  192. type=models.KeywordIndexType.KEYWORD,
  193. is_tenant=True,
  194. on_disk=self.QDRANT_ON_DISK,
  195. ),
  196. wait=True,
  197. )
  198. log.info(
  199. f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
  200. )
  201. except (UnexpectedResponse, grpc.RpcError) as e:
  202. # Check for the specific error indicating collection already exists
  203. status_code, error_msg = self._extract_error_message(e)
  204. # HTTP status code 409 or gRPC ALREADY_EXISTS
  205. if (isinstance(e, UnexpectedResponse) and status_code == 409) or (
  206. isinstance(e, grpc.RpcError)
  207. and e.code() == grpc.StatusCode.ALREADY_EXISTS
  208. ):
  209. if "already exists" in error_msg:
  210. log.debug(f"Collection {mt_collection_name} already exists")
  211. return
  212. # If it's not an already exists error, re-raise
  213. raise e
  214. except Exception as e:
  215. raise e
  216. def _create_points(self, items: list[VectorItem], tenant_id: str):
  217. """
  218. Create point structs from vector items with tenant ID.
  219. """
  220. return [
  221. PointStruct(
  222. id=item["id"],
  223. vector=item["vector"],
  224. payload={
  225. "text": item["text"],
  226. "metadata": item["metadata"],
  227. "tenant_id": tenant_id,
  228. },
  229. )
  230. for item in items
  231. ]
  232. def has_collection(self, collection_name: str) -> bool:
  233. """
  234. Check if a logical collection exists by checking for any points with the tenant ID.
  235. """
  236. if not self.client:
  237. return False
  238. # Map to multi-tenant collection and tenant ID
  239. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  240. # Create tenant filter
  241. tenant_filter = models.FieldCondition(
  242. key="tenant_id", match=models.MatchValue(value=tenant_id)
  243. )
  244. try:
  245. # Try directly querying - most of the time collection should exist
  246. response = self.client.query_points(
  247. collection_name=mt_collection,
  248. query_filter=models.Filter(must=[tenant_filter]),
  249. limit=1,
  250. )
  251. # Collection exists with this tenant ID if there are points
  252. return len(response.points) > 0
  253. except (UnexpectedResponse, grpc.RpcError) as e:
  254. if self._is_collection_not_found_error(e):
  255. log.debug(f"Collection {mt_collection} doesn't exist")
  256. return False
  257. else:
  258. # For other API errors, log and return False
  259. _, error_msg = self._extract_error_message(e)
  260. log.warning(f"Unexpected Qdrant error: {error_msg}")
  261. return False
  262. except Exception as e:
  263. # For any other errors, log and return False
  264. log.debug(f"Error checking collection {mt_collection}: {e}")
  265. return False
  266. def delete(
  267. self,
  268. collection_name: str,
  269. ids: Optional[list[str]] = None,
  270. filter: Optional[dict] = None,
  271. ):
  272. """
  273. Delete vectors by ID or filter from a collection with tenant isolation.
  274. """
  275. if not self.client:
  276. return None
  277. # Map to multi-tenant collection and tenant ID
  278. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  279. # Create tenant filter
  280. tenant_filter = models.FieldCondition(
  281. key="tenant_id", match=models.MatchValue(value=tenant_id)
  282. )
  283. must_conditions = [tenant_filter]
  284. should_conditions = []
  285. if ids:
  286. for id_value in ids:
  287. should_conditions.append(
  288. models.FieldCondition(
  289. key="metadata.id",
  290. match=models.MatchValue(value=id_value),
  291. ),
  292. )
  293. elif filter:
  294. for key, value in filter.items():
  295. must_conditions.append(
  296. models.FieldCondition(
  297. key=f"metadata.{key}",
  298. match=models.MatchValue(value=value),
  299. ),
  300. )
  301. try:
  302. # Try to delete directly - most of the time collection should exist
  303. update_result = self.client.delete(
  304. collection_name=mt_collection,
  305. points_selector=models.FilterSelector(
  306. filter=models.Filter(must=must_conditions, should=should_conditions)
  307. ),
  308. )
  309. return update_result
  310. except (UnexpectedResponse, grpc.RpcError) as e:
  311. if self._is_collection_not_found_error(e):
  312. log.debug(
  313. f"Collection {mt_collection} doesn't exist, nothing to delete"
  314. )
  315. return None
  316. else:
  317. # For other API errors, log and re-raise
  318. _, error_msg = self._extract_error_message(e)
  319. log.warning(f"Unexpected Qdrant error: {error_msg}")
  320. raise
  321. except Exception as e:
  322. # For non-Qdrant exceptions, re-raise
  323. raise
  324. def search(
  325. self, collection_name: str, vectors: list[list[float | int]], limit: int
  326. ) -> Optional[SearchResult]:
  327. """
  328. Search for the nearest neighbor items based on the vectors with tenant isolation.
  329. """
  330. if not self.client:
  331. return None
  332. # Map to multi-tenant collection and tenant ID
  333. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  334. # Get the vector dimension from the query vector
  335. dimension = len(vectors[0]) if vectors and len(vectors) > 0 else None
  336. try:
  337. # Try the search operation directly - most of the time collection should exist
  338. # Create tenant filter
  339. tenant_filter = models.FieldCondition(
  340. key="tenant_id", match=models.MatchValue(value=tenant_id)
  341. )
  342. # Ensure vector dimensions match the collection
  343. collection_dim = self.client.get_collection(
  344. mt_collection
  345. ).config.params.vectors.size
  346. if collection_dim != dimension:
  347. if collection_dim < dimension:
  348. vectors = [vector[:collection_dim] for vector in vectors]
  349. else:
  350. vectors = [
  351. vector + [0] * (collection_dim - dimension)
  352. for vector in vectors
  353. ]
  354. # Search with tenant filter
  355. prefetch_query = models.Prefetch(
  356. filter=models.Filter(must=[tenant_filter]),
  357. limit=NO_LIMIT,
  358. )
  359. query_response = self.client.query_points(
  360. collection_name=mt_collection,
  361. query=vectors[0],
  362. prefetch=prefetch_query,
  363. limit=limit,
  364. )
  365. get_result = self._result_to_get_result(query_response.points)
  366. return SearchResult(
  367. ids=get_result.ids,
  368. documents=get_result.documents,
  369. metadatas=get_result.metadatas,
  370. # qdrant distance is [-1, 1], normalize to [0, 1]
  371. distances=[
  372. [(point.score + 1.0) / 2.0 for point in query_response.points]
  373. ],
  374. )
  375. except (UnexpectedResponse, grpc.RpcError) as e:
  376. if self._is_collection_not_found_error(e):
  377. log.debug(
  378. f"Collection {mt_collection} doesn't exist, search returns None"
  379. )
  380. return None
  381. else:
  382. # For other API errors, log and re-raise
  383. _, error_msg = self._extract_error_message(e)
  384. log.warning(f"Unexpected Qdrant error during search: {error_msg}")
  385. raise
  386. except Exception as e:
  387. # For non-Qdrant exceptions, log and return None
  388. log.exception(f"Error searching collection '{collection_name}': {e}")
  389. return None
  390. def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
  391. """
  392. Query points with filters and tenant isolation.
  393. """
  394. if not self.client:
  395. return None
  396. # Map to multi-tenant collection and tenant ID
  397. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  398. # Set default limit if not provided
  399. if limit is None:
  400. limit = NO_LIMIT
  401. # Create tenant filter
  402. tenant_filter = models.FieldCondition(
  403. key="tenant_id", match=models.MatchValue(value=tenant_id)
  404. )
  405. # Create metadata filters
  406. field_conditions = []
  407. for key, value in filter.items():
  408. field_conditions.append(
  409. models.FieldCondition(
  410. key=f"metadata.{key}", match=models.MatchValue(value=value)
  411. )
  412. )
  413. # Combine tenant filter with metadata filters
  414. combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
  415. try:
  416. # Try the query directly - most of the time collection should exist
  417. points = self.client.query_points(
  418. collection_name=mt_collection,
  419. query_filter=combined_filter,
  420. limit=limit,
  421. )
  422. return self._result_to_get_result(points.points)
  423. except (UnexpectedResponse, grpc.RpcError) as e:
  424. if self._is_collection_not_found_error(e):
  425. log.debug(
  426. f"Collection {mt_collection} doesn't exist, query returns None"
  427. )
  428. return None
  429. else:
  430. # For other API errors, log and re-raise
  431. _, error_msg = self._extract_error_message(e)
  432. log.warning(f"Unexpected Qdrant error during query: {error_msg}")
  433. raise
  434. except Exception as e:
  435. # For non-Qdrant exceptions, log and re-raise
  436. log.exception(f"Error querying collection '{collection_name}': {e}")
  437. return None
  438. def get(self, collection_name: str) -> Optional[GetResult]:
  439. """
  440. Get all items in a collection with tenant isolation.
  441. """
  442. if not self.client:
  443. return None
  444. # Map to multi-tenant collection and tenant ID
  445. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  446. # Create tenant filter
  447. tenant_filter = models.FieldCondition(
  448. key="tenant_id", match=models.MatchValue(value=tenant_id)
  449. )
  450. try:
  451. # Try to get points directly - most of the time collection should exist
  452. points = self.client.query_points(
  453. collection_name=mt_collection,
  454. query_filter=models.Filter(must=[tenant_filter]),
  455. limit=NO_LIMIT,
  456. )
  457. return self._result_to_get_result(points.points)
  458. except (UnexpectedResponse, grpc.RpcError) as e:
  459. if self._is_collection_not_found_error(e):
  460. log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
  461. return None
  462. else:
  463. # For other API errors, log and re-raise
  464. _, error_msg = self._extract_error_message(e)
  465. log.warning(f"Unexpected Qdrant error during get: {error_msg}")
  466. raise
  467. except Exception as e:
  468. # For non-Qdrant exceptions, log and return None
  469. log.exception(f"Error getting collection '{collection_name}': {e}")
  470. return None
  471. def _handle_operation_with_error_retry(
  472. self, operation_name, mt_collection, points, dimension
  473. ):
  474. """
  475. Private helper to handle common error cases for insert and upsert operations.
  476. Args:
  477. operation_name: 'insert' or 'upsert'
  478. mt_collection: The multi-tenant collection name
  479. points: The vector points to insert/upsert
  480. dimension: The dimension of the vectors
  481. Returns:
  482. The operation result (for upsert) or None (for insert)
  483. """
  484. try:
  485. if operation_name == "insert":
  486. self.client.upload_points(mt_collection, points)
  487. return None
  488. else: # upsert
  489. return self.client.upsert(mt_collection, points)
  490. except (UnexpectedResponse, grpc.RpcError) as e:
  491. # Handle collection not found
  492. if self._is_collection_not_found_error(e):
  493. log.info(
  494. f"Collection {mt_collection} doesn't exist. Creating it with dimension {dimension}."
  495. )
  496. # Create collection with correct dimensions from our vectors
  497. self._create_multi_tenant_collection_if_not_exists(
  498. mt_collection_name=mt_collection, dimension=dimension
  499. )
  500. # Try operation again - no need for dimension adjustment since we just created with correct dimensions
  501. if operation_name == "insert":
  502. self.client.upload_points(mt_collection, points)
  503. return None
  504. else: # upsert
  505. return self.client.upsert(mt_collection, points)
  506. # Handle dimension mismatch
  507. elif self._is_dimension_mismatch_error(e):
  508. # For dimension errors, the collection must exist, so get its configuration
  509. mt_collection_info = self.client.get_collection(mt_collection)
  510. existing_size = mt_collection_info.config.params.vectors.size
  511. log.info(
  512. f"Dimension mismatch: Collection {mt_collection} expects {existing_size}, got {dimension}"
  513. )
  514. if existing_size < dimension:
  515. # Truncate vectors to fit
  516. log.info(
  517. f"Truncating vectors from {dimension} to {existing_size} dimensions"
  518. )
  519. points = [
  520. PointStruct(
  521. id=point.id,
  522. vector=point.vector[:existing_size],
  523. payload=point.payload,
  524. )
  525. for point in points
  526. ]
  527. elif existing_size > dimension:
  528. # Pad vectors with zeros
  529. log.info(
  530. f"Padding vectors from {dimension} to {existing_size} dimensions with zeros"
  531. )
  532. points = [
  533. PointStruct(
  534. id=point.id,
  535. vector=point.vector
  536. + [0] * (existing_size - len(point.vector)),
  537. payload=point.payload,
  538. )
  539. for point in points
  540. ]
  541. # Try operation again with adjusted dimensions
  542. if operation_name == "insert":
  543. self.client.upload_points(mt_collection, points)
  544. return None
  545. else: # upsert
  546. return self.client.upsert(mt_collection, points)
  547. else:
  548. # Not a known error we can handle, log and re-raise
  549. _, error_msg = self._extract_error_message(e)
  550. log.warning(f"Unhandled Qdrant error: {error_msg}")
  551. raise
  552. except Exception as e:
  553. # For non-Qdrant exceptions, re-raise
  554. raise
  555. def insert(self, collection_name: str, items: list[VectorItem]):
  556. """
  557. Insert items with tenant ID.
  558. """
  559. if not self.client or not items:
  560. return None
  561. # Map to multi-tenant collection and tenant ID
  562. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  563. # Get dimensions from the actual vectors
  564. dimension = len(items[0]["vector"]) if items else None
  565. # Create points with tenant ID
  566. points = self._create_points(items, tenant_id)
  567. # Handle the operation with error retry
  568. return self._handle_operation_with_error_retry(
  569. "insert", mt_collection, points, dimension
  570. )
  571. def upsert(self, collection_name: str, items: list[VectorItem]):
  572. """
  573. Upsert items with tenant ID.
  574. """
  575. if not self.client or not items:
  576. return None
  577. # Map to multi-tenant collection and tenant ID
  578. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  579. # Get dimensions from the actual vectors
  580. dimension = len(items[0]["vector"]) if items else None
  581. # Create points with tenant ID
  582. points = self._create_points(items, tenant_id)
  583. # Handle the operation with error retry
  584. return self._handle_operation_with_error_retry(
  585. "upsert", mt_collection, points, dimension
  586. )
  587. def reset(self):
  588. """
  589. Reset the database by deleting all collections.
  590. """
  591. if not self.client:
  592. return None
  593. collection_names = self.client.get_collections().collections
  594. for collection_name in collection_names:
  595. if collection_name.name.startswith(self.collection_prefix):
  596. self.client.delete_collection(collection_name=collection_name.name)
  597. def delete_collection(self, collection_name: str):
  598. """
  599. Delete a collection.
  600. """
  601. if not self.client:
  602. return None
  603. # Map to multi-tenant collection and tenant ID
  604. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  605. tenant_filter = models.FieldCondition(
  606. key="tenant_id", match=models.MatchValue(value=tenant_id)
  607. )
  608. field_conditions = [tenant_filter]
  609. update_result = self.client.delete(
  610. collection_name=mt_collection,
  611. points_selector=models.FilterSelector(
  612. filter=models.Filter(must=field_conditions)
  613. ),
  614. )
  615. if self.client.get_collection(mt_collection).points_count == 0:
  616. self.client.delete_collection(mt_collection)
  617. return update_result