# qdrant_multitenancy.py (26 KB)


  1. import logging
  2. from typing import Optional, Tuple
  3. from urllib.parse import urlparse
  4. import grpc
  5. from open_webui.config import (
  6. QDRANT_API_KEY,
  7. QDRANT_GRPC_PORT,
  8. QDRANT_ON_DISK,
  9. QDRANT_PREFER_GRPC,
  10. QDRANT_URI,
  11. QDRANT_COLLECTION_PREFIX,
  12. )
  13. from open_webui.env import SRC_LOG_LEVELS
  14. from open_webui.retrieval.vector.main import (
  15. GetResult,
  16. SearchResult,
  17. VectorDBBase,
  18. VectorItem,
  19. )
  20. from qdrant_client import QdrantClient as Qclient
  21. from qdrant_client.http.exceptions import UnexpectedResponse
  22. from qdrant_client.http.models import PointStruct
  23. from qdrant_client.models import models
  24. NO_LIMIT = 999999999
  25. log = logging.getLogger(__name__)
  26. log.setLevel(SRC_LOG_LEVELS["RAG"])
  27. class QdrantClient(VectorDBBase):
  28. def __init__(self):
  29. self.collection_prefix = QDRANT_COLLECTION_PREFIX
  30. self.QDRANT_URI = QDRANT_URI
  31. self.QDRANT_API_KEY = QDRANT_API_KEY
  32. self.QDRANT_ON_DISK = QDRANT_ON_DISK
  33. self.PREFER_GRPC = QDRANT_PREFER_GRPC
  34. self.GRPC_PORT = QDRANT_GRPC_PORT
  35. if not self.QDRANT_URI:
  36. self.client = None
  37. return
  38. # Unified handling for either scheme
  39. parsed = urlparse(self.QDRANT_URI)
  40. host = parsed.hostname or self.QDRANT_URI
  41. http_port = parsed.port or 6333 # default REST port
  42. if self.PREFER_GRPC:
  43. self.client = Qclient(
  44. host=host,
  45. port=http_port,
  46. grpc_port=self.GRPC_PORT,
  47. prefer_grpc=self.PREFER_GRPC,
  48. api_key=self.QDRANT_API_KEY,
  49. )
  50. else:
  51. self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
  52. # Main collection types for multi-tenancy
  53. self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
  54. self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
  55. self.FILE_COLLECTION = f"{self.collection_prefix}_files"
  56. self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web-search"
  57. self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
  58. def _result_to_get_result(self, points) -> GetResult:
  59. ids = []
  60. documents = []
  61. metadatas = []
  62. for point in points:
  63. payload = point.payload
  64. ids.append(point.id)
  65. documents.append(payload["text"])
  66. metadatas.append(payload["metadata"])
  67. return GetResult(
  68. **{
  69. "ids": [ids],
  70. "documents": [documents],
  71. "metadatas": [metadatas],
  72. }
  73. )
  74. def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
  75. """
  76. Maps the traditional collection name to multi-tenant collection and tenant ID.
  77. Returns:
  78. tuple: (collection_name, tenant_id)
  79. """
  80. # Check for user memory collections
  81. tenant_id = collection_name
  82. if collection_name.startswith("user-memory-"):
  83. return self.MEMORY_COLLECTION, tenant_id
  84. # Check for file collections
  85. elif collection_name.startswith("file-"):
  86. return self.FILE_COLLECTION, tenant_id
  87. # Check for web search collections
  88. elif collection_name.startswith("web-search-"):
  89. return self.WEB_SEARCH_COLLECTION, tenant_id
  90. # Handle hash-based collections (YouTube and web URLs)
  91. elif len(collection_name) == 63 and all(
  92. c in "0123456789abcdef" for c in collection_name
  93. ):
  94. return self.HASH_BASED_COLLECTION, tenant_id
  95. else:
  96. return self.KNOWLEDGE_COLLECTION, tenant_id
  97. def _extract_error_message(self, exception):
  98. """
  99. Extract error message from either HTTP or gRPC exceptions
  100. Returns:
  101. tuple: (status_code, error_message)
  102. """
  103. # Check if it's an HTTP exception
  104. if isinstance(exception, UnexpectedResponse):
  105. try:
  106. error_data = exception.structured()
  107. error_msg = error_data.get("status", {}).get("error", "")
  108. return exception.status_code, error_msg
  109. except Exception as inner_e:
  110. log.error(f"Failed to parse HTTP error: {inner_e}")
  111. return exception.status_code, str(exception)
  112. # Check if it's a gRPC exception
  113. elif isinstance(exception, grpc.RpcError):
  114. # Extract status code from gRPC error
  115. status_code = None
  116. if hasattr(exception, "code") and callable(exception.code):
  117. status_code = exception.code().value[0]
  118. # Extract error message
  119. error_msg = str(exception)
  120. if "details =" in error_msg:
  121. # Parse the details line which contains the actual error message
  122. try:
  123. details_line = [
  124. line.strip()
  125. for line in error_msg.split("\n")
  126. if "details =" in line
  127. ][0]
  128. error_msg = details_line.split("details =")[1].strip(' "')
  129. except (IndexError, AttributeError):
  130. # Fall back to full message if parsing fails
  131. pass
  132. return status_code, error_msg
  133. # For any other type of exception
  134. return None, str(exception)
  135. def _is_collection_not_found_error(self, exception):
  136. """
  137. Check if the exception is due to collection not found, supporting both HTTP and gRPC
  138. """
  139. status_code, error_msg = self._extract_error_message(exception)
  140. # HTTP error (404)
  141. if (
  142. status_code == 404
  143. and "Collection" in error_msg
  144. and "doesn't exist" in error_msg
  145. ):
  146. return True
  147. # gRPC error (NOT_FOUND status)
  148. if (
  149. isinstance(exception, grpc.RpcError)
  150. and exception.code() == grpc.StatusCode.NOT_FOUND
  151. ):
  152. return True
  153. return False
  154. def _is_dimension_mismatch_error(self, exception):
  155. """
  156. Check if the exception is due to dimension mismatch, supporting both HTTP and gRPC
  157. """
  158. status_code, error_msg = self._extract_error_message(exception)
  159. # Common patterns in both HTTP and gRPC
  160. return (
  161. "Vector dimension error" in error_msg
  162. or "dimensions mismatch" in error_msg
  163. or "invalid vector size" in error_msg
  164. )
  165. def _create_multi_tenant_collection_if_not_exists(
  166. self, mt_collection_name: str, dimension: int = 384
  167. ):
  168. """
  169. Creates a collection with multi-tenancy configuration if it doesn't exist.
  170. Default dimension is set to 384 which corresponds to 'sentence-transformers/all-MiniLM-L6-v2'.
  171. When creating collections dynamically (insert/upsert), the actual vector dimensions will be used.
  172. """
  173. try:
  174. # Try to create the collection directly - will fail if it already exists
  175. self.client.create_collection(
  176. collection_name=mt_collection_name,
  177. vectors_config=models.VectorParams(
  178. size=dimension,
  179. distance=models.Distance.COSINE,
  180. on_disk=self.QDRANT_ON_DISK,
  181. ),
  182. hnsw_config=models.HnswConfigDiff(
  183. payload_m=16, # Enable per-tenant indexing
  184. m=0,
  185. on_disk=self.QDRANT_ON_DISK,
  186. ),
  187. )
  188. # Create tenant ID payload index
  189. self.client.create_payload_index(
  190. collection_name=mt_collection_name,
  191. field_name="tenant_id",
  192. field_schema=models.KeywordIndexParams(
  193. type=models.KeywordIndexType.KEYWORD,
  194. is_tenant=True,
  195. on_disk=self.QDRANT_ON_DISK,
  196. ),
  197. wait=True,
  198. )
  199. log.info(
  200. f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
  201. )
  202. except (UnexpectedResponse, grpc.RpcError) as e:
  203. # Check for the specific error indicating collection already exists
  204. status_code, error_msg = self._extract_error_message(e)
  205. # HTTP status code 409 or gRPC ALREADY_EXISTS
  206. if (isinstance(e, UnexpectedResponse) and status_code == 409) or (
  207. isinstance(e, grpc.RpcError)
  208. and e.code() == grpc.StatusCode.ALREADY_EXISTS
  209. ):
  210. if "already exists" in error_msg:
  211. log.debug(f"Collection {mt_collection_name} already exists")
  212. return
  213. # If it's not an already exists error, re-raise
  214. raise e
  215. except Exception as e:
  216. raise e
  217. def _create_points(self, items: list[VectorItem], tenant_id: str):
  218. """
  219. Create point structs from vector items with tenant ID.
  220. """
  221. return [
  222. PointStruct(
  223. id=item["id"],
  224. vector=item["vector"],
  225. payload={
  226. "text": item["text"],
  227. "metadata": item["metadata"],
  228. "tenant_id": tenant_id,
  229. },
  230. )
  231. for item in items
  232. ]
  233. def has_collection(self, collection_name: str) -> bool:
  234. """
  235. Check if a logical collection exists by checking for any points with the tenant ID.
  236. """
  237. if not self.client:
  238. return False
  239. # Map to multi-tenant collection and tenant ID
  240. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  241. # Create tenant filter
  242. tenant_filter = models.FieldCondition(
  243. key="tenant_id", match=models.MatchValue(value=tenant_id)
  244. )
  245. try:
  246. # Try directly querying - most of the time collection should exist
  247. response = self.client.query_points(
  248. collection_name=mt_collection,
  249. query_filter=models.Filter(must=[tenant_filter]),
  250. limit=1,
  251. )
  252. # Collection exists with this tenant ID if there are points
  253. return len(response.points) > 0
  254. except (UnexpectedResponse, grpc.RpcError) as e:
  255. if self._is_collection_not_found_error(e):
  256. log.debug(f"Collection {mt_collection} doesn't exist")
  257. return False
  258. else:
  259. # For other API errors, log and return False
  260. _, error_msg = self._extract_error_message(e)
  261. log.warning(f"Unexpected Qdrant error: {error_msg}")
  262. return False
  263. except Exception as e:
  264. # For any other errors, log and return False
  265. log.debug(f"Error checking collection {mt_collection}: {e}")
  266. return False
  267. def delete(
  268. self,
  269. collection_name: str,
  270. ids: Optional[list[str]] = None,
  271. filter: Optional[dict] = None,
  272. ):
  273. """
  274. Delete vectors by ID or filter from a collection with tenant isolation.
  275. """
  276. if not self.client:
  277. return None
  278. # Map to multi-tenant collection and tenant ID
  279. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  280. # Create tenant filter
  281. tenant_filter = models.FieldCondition(
  282. key="tenant_id", match=models.MatchValue(value=tenant_id)
  283. )
  284. must_conditions = [tenant_filter]
  285. should_conditions = []
  286. if ids:
  287. for id_value in ids:
  288. should_conditions.append(
  289. models.FieldCondition(
  290. key="metadata.id",
  291. match=models.MatchValue(value=id_value),
  292. ),
  293. )
  294. elif filter:
  295. for key, value in filter.items():
  296. must_conditions.append(
  297. models.FieldCondition(
  298. key=f"metadata.{key}",
  299. match=models.MatchValue(value=value),
  300. ),
  301. )
  302. try:
  303. # Try to delete directly - most of the time collection should exist
  304. update_result = self.client.delete(
  305. collection_name=mt_collection,
  306. points_selector=models.FilterSelector(
  307. filter=models.Filter(must=must_conditions, should=should_conditions)
  308. ),
  309. )
  310. return update_result
  311. except (UnexpectedResponse, grpc.RpcError) as e:
  312. if self._is_collection_not_found_error(e):
  313. log.debug(
  314. f"Collection {mt_collection} doesn't exist, nothing to delete"
  315. )
  316. return None
  317. else:
  318. # For other API errors, log and re-raise
  319. _, error_msg = self._extract_error_message(e)
  320. log.warning(f"Unexpected Qdrant error: {error_msg}")
  321. raise
  322. except Exception as e:
  323. # For non-Qdrant exceptions, re-raise
  324. raise
  325. def search(
  326. self, collection_name: str, vectors: list[list[float | int]], limit: int
  327. ) -> Optional[SearchResult]:
  328. """
  329. Search for the nearest neighbor items based on the vectors with tenant isolation.
  330. """
  331. if not self.client:
  332. return None
  333. # Map to multi-tenant collection and tenant ID
  334. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  335. # Get the vector dimension from the query vector
  336. dimension = len(vectors[0]) if vectors and len(vectors) > 0 else None
  337. try:
  338. # Try the search operation directly - most of the time collection should exist
  339. # Create tenant filter
  340. tenant_filter = models.FieldCondition(
  341. key="tenant_id", match=models.MatchValue(value=tenant_id)
  342. )
  343. # Ensure vector dimensions match the collection
  344. collection_dim = self.client.get_collection(
  345. mt_collection
  346. ).config.params.vectors.size
  347. if collection_dim != dimension:
  348. if collection_dim < dimension:
  349. vectors = [vector[:collection_dim] for vector in vectors]
  350. else:
  351. vectors = [
  352. vector + [0] * (collection_dim - dimension)
  353. for vector in vectors
  354. ]
  355. # Search with tenant filter
  356. prefetch_query = models.Prefetch(
  357. filter=models.Filter(must=[tenant_filter]),
  358. limit=NO_LIMIT,
  359. )
  360. query_response = self.client.query_points(
  361. collection_name=mt_collection,
  362. query=vectors[0],
  363. prefetch=prefetch_query,
  364. limit=limit,
  365. )
  366. get_result = self._result_to_get_result(query_response.points)
  367. return SearchResult(
  368. ids=get_result.ids,
  369. documents=get_result.documents,
  370. metadatas=get_result.metadatas,
  371. # qdrant distance is [-1, 1], normalize to [0, 1]
  372. distances=[
  373. [(point.score + 1.0) / 2.0 for point in query_response.points]
  374. ],
  375. )
  376. except (UnexpectedResponse, grpc.RpcError) as e:
  377. if self._is_collection_not_found_error(e):
  378. log.debug(
  379. f"Collection {mt_collection} doesn't exist, search returns None"
  380. )
  381. return None
  382. else:
  383. # For other API errors, log and re-raise
  384. _, error_msg = self._extract_error_message(e)
  385. log.warning(f"Unexpected Qdrant error during search: {error_msg}")
  386. raise
  387. except Exception as e:
  388. # For non-Qdrant exceptions, log and return None
  389. log.exception(f"Error searching collection '{collection_name}': {e}")
  390. return None
  391. def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
  392. """
  393. Query points with filters and tenant isolation.
  394. """
  395. if not self.client:
  396. return None
  397. # Map to multi-tenant collection and tenant ID
  398. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  399. # Set default limit if not provided
  400. if limit is None:
  401. limit = NO_LIMIT
  402. # Create tenant filter
  403. tenant_filter = models.FieldCondition(
  404. key="tenant_id", match=models.MatchValue(value=tenant_id)
  405. )
  406. # Create metadata filters
  407. field_conditions = []
  408. for key, value in filter.items():
  409. field_conditions.append(
  410. models.FieldCondition(
  411. key=f"metadata.{key}", match=models.MatchValue(value=value)
  412. )
  413. )
  414. # Combine tenant filter with metadata filters
  415. combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
  416. try:
  417. # Try the query directly - most of the time collection should exist
  418. points = self.client.query_points(
  419. collection_name=mt_collection,
  420. query_filter=combined_filter,
  421. limit=limit,
  422. )
  423. return self._result_to_get_result(points.points)
  424. except (UnexpectedResponse, grpc.RpcError) as e:
  425. if self._is_collection_not_found_error(e):
  426. log.debug(
  427. f"Collection {mt_collection} doesn't exist, query returns None"
  428. )
  429. return None
  430. else:
  431. # For other API errors, log and re-raise
  432. _, error_msg = self._extract_error_message(e)
  433. log.warning(f"Unexpected Qdrant error during query: {error_msg}")
  434. raise
  435. except Exception as e:
  436. # For non-Qdrant exceptions, log and re-raise
  437. log.exception(f"Error querying collection '{collection_name}': {e}")
  438. return None
  439. def get(self, collection_name: str) -> Optional[GetResult]:
  440. """
  441. Get all items in a collection with tenant isolation.
  442. """
  443. if not self.client:
  444. return None
  445. # Map to multi-tenant collection and tenant ID
  446. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  447. # Create tenant filter
  448. tenant_filter = models.FieldCondition(
  449. key="tenant_id", match=models.MatchValue(value=tenant_id)
  450. )
  451. try:
  452. # Try to get points directly - most of the time collection should exist
  453. points = self.client.query_points(
  454. collection_name=mt_collection,
  455. query_filter=models.Filter(must=[tenant_filter]),
  456. limit=NO_LIMIT,
  457. )
  458. return self._result_to_get_result(points.points)
  459. except (UnexpectedResponse, grpc.RpcError) as e:
  460. if self._is_collection_not_found_error(e):
  461. log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
  462. return None
  463. else:
  464. # For other API errors, log and re-raise
  465. _, error_msg = self._extract_error_message(e)
  466. log.warning(f"Unexpected Qdrant error during get: {error_msg}")
  467. raise
  468. except Exception as e:
  469. # For non-Qdrant exceptions, log and return None
  470. log.exception(f"Error getting collection '{collection_name}': {e}")
  471. return None
  472. def _handle_operation_with_error_retry(
  473. self, operation_name, mt_collection, points, dimension
  474. ):
  475. """
  476. Private helper to handle common error cases for insert and upsert operations.
  477. Args:
  478. operation_name: 'insert' or 'upsert'
  479. mt_collection: The multi-tenant collection name
  480. points: The vector points to insert/upsert
  481. dimension: The dimension of the vectors
  482. Returns:
  483. The operation result (for upsert) or None (for insert)
  484. """
  485. try:
  486. if operation_name == "insert":
  487. self.client.upload_points(mt_collection, points)
  488. return None
  489. else: # upsert
  490. return self.client.upsert(mt_collection, points)
  491. except (UnexpectedResponse, grpc.RpcError) as e:
  492. # Handle collection not found
  493. if self._is_collection_not_found_error(e):
  494. log.info(
  495. f"Collection {mt_collection} doesn't exist. Creating it with dimension {dimension}."
  496. )
  497. # Create collection with correct dimensions from our vectors
  498. self._create_multi_tenant_collection_if_not_exists(
  499. mt_collection_name=mt_collection, dimension=dimension
  500. )
  501. # Try operation again - no need for dimension adjustment since we just created with correct dimensions
  502. if operation_name == "insert":
  503. self.client.upload_points(mt_collection, points)
  504. return None
  505. else: # upsert
  506. return self.client.upsert(mt_collection, points)
  507. # Handle dimension mismatch
  508. elif self._is_dimension_mismatch_error(e):
  509. # For dimension errors, the collection must exist, so get its configuration
  510. mt_collection_info = self.client.get_collection(mt_collection)
  511. existing_size = mt_collection_info.config.params.vectors.size
  512. log.info(
  513. f"Dimension mismatch: Collection {mt_collection} expects {existing_size}, got {dimension}"
  514. )
  515. if existing_size < dimension:
  516. # Truncate vectors to fit
  517. log.info(
  518. f"Truncating vectors from {dimension} to {existing_size} dimensions"
  519. )
  520. points = [
  521. PointStruct(
  522. id=point.id,
  523. vector=point.vector[:existing_size],
  524. payload=point.payload,
  525. )
  526. for point in points
  527. ]
  528. elif existing_size > dimension:
  529. # Pad vectors with zeros
  530. log.info(
  531. f"Padding vectors from {dimension} to {existing_size} dimensions with zeros"
  532. )
  533. points = [
  534. PointStruct(
  535. id=point.id,
  536. vector=point.vector
  537. + [0] * (existing_size - len(point.vector)),
  538. payload=point.payload,
  539. )
  540. for point in points
  541. ]
  542. # Try operation again with adjusted dimensions
  543. if operation_name == "insert":
  544. self.client.upload_points(mt_collection, points)
  545. return None
  546. else: # upsert
  547. return self.client.upsert(mt_collection, points)
  548. else:
  549. # Not a known error we can handle, log and re-raise
  550. _, error_msg = self._extract_error_message(e)
  551. log.warning(f"Unhandled Qdrant error: {error_msg}")
  552. raise
  553. except Exception as e:
  554. # For non-Qdrant exceptions, re-raise
  555. raise
  556. def insert(self, collection_name: str, items: list[VectorItem]):
  557. """
  558. Insert items with tenant ID.
  559. """
  560. if not self.client or not items:
  561. return None
  562. # Map to multi-tenant collection and tenant ID
  563. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  564. # Get dimensions from the actual vectors
  565. dimension = len(items[0]["vector"]) if items else None
  566. # Create points with tenant ID
  567. points = self._create_points(items, tenant_id)
  568. # Handle the operation with error retry
  569. return self._handle_operation_with_error_retry(
  570. "insert", mt_collection, points, dimension
  571. )
  572. def upsert(self, collection_name: str, items: list[VectorItem]):
  573. """
  574. Upsert items with tenant ID.
  575. """
  576. if not self.client or not items:
  577. return None
  578. # Map to multi-tenant collection and tenant ID
  579. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  580. # Get dimensions from the actual vectors
  581. dimension = len(items[0]["vector"]) if items else None
  582. # Create points with tenant ID
  583. points = self._create_points(items, tenant_id)
  584. # Handle the operation with error retry
  585. return self._handle_operation_with_error_retry(
  586. "upsert", mt_collection, points, dimension
  587. )
  588. def reset(self):
  589. """
  590. Reset the database by deleting all collections.
  591. """
  592. if not self.client:
  593. return None
  594. collection_names = self.client.get_collections().collections
  595. for collection_name in collection_names:
  596. if collection_name.name.startswith(self.collection_prefix):
  597. self.client.delete_collection(collection_name=collection_name.name)
  598. def delete_collection(self, collection_name: str):
  599. """
  600. Delete a collection.
  601. """
  602. if not self.client:
  603. return None
  604. # Map to multi-tenant collection and tenant ID
  605. mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
  606. tenant_filter = models.FieldCondition(
  607. key="tenant_id", match=models.MatchValue(value=tenant_id)
  608. )
  609. field_conditions = [tenant_filter]
  610. update_result = self.client.delete(
  611. collection_name=mt_collection,
  612. points_selector=models.FilterSelector(
  613. filter=models.Filter(must=field_conditions)
  614. ),
  615. )
  616. if self.client.get_collection(mt_collection).points_count == 0:
  617. self.client.delete_collection(mt_collection)
  618. return update_result