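"""Pinecone vector database client for Open WebUI.

Implements the VectorDBBase interface on top of the Pinecone gRPC client.
All logical "collections" live in a single Pinecone index and are separated
by a `collection_name` metadata field added to every vector.
"""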
from typing import Optional, List, Dict, Any, Union
import logging
import time  # for measuring elapsed time
import asyncio  # for async upserts
import functools  # for partial binding in async tasks
import concurrent.futures  # for parallel batch upserts

from pinecone import ServerlessSpec
from pinecone.grpc import PineconeGRPC  # use gRPC client for faster upserts

from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    PINECONE_API_KEY,
    PINECONE_ENVIRONMENT,
    PINECONE_INDEX_NAME,
    PINECONE_DIMENSION,
    PINECONE_METRIC,
    PINECONE_CLOUD,
)
from open_webui.env import SRC_LOG_LEVELS

NO_LIMIT = 10000  # Practical cap applied when no explicit limit is given, to avoid overwhelming the system
BATCH_SIZE = 100  # Recommended batch size for Pinecone operations

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class PineconeClient(VectorDBBase):
    def __init__(self):
        self.collection_prefix = "open-webui"

        # Validate required configuration
        self._validate_config()

        # Store configuration values
        self.api_key = PINECONE_API_KEY
        self.environment = PINECONE_ENVIRONMENT
        self.index_name = PINECONE_INDEX_NAME
        self.dimension = PINECONE_DIMENSION
        self.metric = PINECONE_METRIC
        self.cloud = PINECONE_CLOUD

        # Initialize Pinecone gRPC client for improved performance
        self.client = PineconeGRPC(
            api_key=self.api_key, environment=self.environment, cloud=self.cloud
        )

        # Persistent executor for batch operations
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)

        # Create index if it doesn't exist
        self._initialize_index()

    def _validate_config(self) -> None:
        """Validate that all required configuration variables are set."""
        missing_vars = []
        if not PINECONE_API_KEY:
            missing_vars.append("PINECONE_API_KEY")
        if not PINECONE_ENVIRONMENT:
            missing_vars.append("PINECONE_ENVIRONMENT")
        if not PINECONE_INDEX_NAME:
            missing_vars.append("PINECONE_INDEX_NAME")
        if not PINECONE_DIMENSION:
            missing_vars.append("PINECONE_DIMENSION")
        if not PINECONE_CLOUD:
            missing_vars.append("PINECONE_CLOUD")

        if missing_vars:
            raise ValueError(
                f"Required configuration missing: {', '.join(missing_vars)}"
            )

    def _initialize_index(self) -> None:
        """Initialize the Pinecone index."""
        try:
            # Check if index exists
            if self.index_name not in self.client.list_indexes().names():
                log.info(f"Creating Pinecone index '{self.index_name}'...")
                self.client.create_index(
                    name=self.index_name,
                    dimension=self.dimension,
                    metric=self.metric,
                    spec=ServerlessSpec(cloud=self.cloud, region=self.environment),
                )
                log.info(f"Successfully created Pinecone index '{self.index_name}'")
            else:
                log.info(f"Using existing Pinecone index '{self.index_name}'")

            # Connect to the index
            self.index = self.client.Index(self.index_name)
        except Exception as e:
            log.error(f"Failed to initialize Pinecone index: {e}")
            raise RuntimeError(f"Failed to initialize Pinecone index: {e}")

    def _create_points(
        self, items: List[VectorItem], collection_name_with_prefix: str
    ) -> List[Dict[str, Any]]:
        """Convert VectorItem objects to Pinecone point format."""
        points = []
        for item in items:
            # Start with any existing metadata or an empty dict
            metadata = item.get("metadata", {}).copy() if item.get("metadata") else {}

            # Add text to metadata if available
            if "text" in item:
                metadata["text"] = item["text"]

            # Always add collection_name to metadata for filtering
            metadata["collection_name"] = collection_name_with_prefix

            point = {
                "id": item["id"],
                "values": item["vector"],
                "metadata": metadata,
            }
            points.append(point)
        return points

    def _get_collection_name_with_prefix(self, collection_name: str) -> str:
        """Get the collection name with prefix."""
        return f"{self.collection_prefix}_{collection_name}"

    def _normalize_distance(self, score: float) -> float:
        """Normalize distance score based on the metric used."""
        if self.metric.lower() == "cosine":
            # Cosine similarity ranges from -1 to 1; normalize to 0 to 1
            # (e.g. a raw score of 0.8 maps to (0.8 + 1.0) / 2.0 = 0.9)
            return (score + 1.0) / 2.0
        elif self.metric.lower() in ["euclidean", "dotproduct"]:
            # These are already suitable for ranking (smaller is better for Euclidean)
            return score
        else:
            # For other metrics, use as is
            return score

    def _result_to_get_result(self, matches: list) -> GetResult:
        """Convert Pinecone matches to GetResult format."""
        ids = []
        documents = []
        metadatas = []

        for match in matches:
            metadata = match.get("metadata", {})
            ids.append(match["id"])
            documents.append(metadata.get("text", ""))
            metadatas.append(metadata)

        return GetResult(
            **{
                "ids": [ids],
                "documents": [documents],
                "metadatas": [metadatas],
            }
        )

    def has_collection(self, collection_name: str) -> bool:
        """Check if a collection exists by searching for at least one item."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        try:
            # Search for at least 1 item with this collection name in metadata
            response = self.index.query(
                vector=[0.0] * self.dimension,  # dummy vector
                top_k=1,
                filter={"collection_name": collection_name_with_prefix},
                include_metadata=False,
            )
            return len(response.matches) > 0
        except Exception as e:
            log.exception(
                f"Error checking collection '{collection_name_with_prefix}': {e}"
            )
            return False

    def delete_collection(self, collection_name: str) -> None:
        """Delete a collection by removing all vectors with the collection name in metadata."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            self.index.delete(filter={"collection_name": collection_name_with_prefix})
            log.info(
                f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)."
            )
        except Exception as e:
            log.warning(
                f"Failed to delete collection '{collection_name_with_prefix}': {e}"
            )
            raise

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert vectors into a collection."""
        if not items:
            log.warning("No items to insert")
            return

        start_time = time.time()
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Parallelize batch inserts for performance
        executor = self._executor
        futures = []
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            futures.append(executor.submit(self.index.upsert, vectors=batch))
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                log.error(f"Error inserting batch: {e}")
                raise

        elapsed = time.time() - start_time
        log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
        log.info(
            f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
        )

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Upsert (insert or update) vectors into a collection."""
        if not items:
            log.warning("No items to upsert")
            return

        start_time = time.time()
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Parallelize batch upserts for performance
        executor = self._executor
        futures = []
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            futures.append(executor.submit(self.index.upsert, vectors=batch))
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                log.error(f"Error upserting batch: {e}")
                raise

        elapsed = time.time() - start_time
        log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
        log.info(
            f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
        )

    async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
        """Async version of insert using asyncio and run_in_executor for improved performance."""
        if not items:
            log.warning("No items to insert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Create batches
        batches = [
            points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
        ]

        loop = asyncio.get_event_loop()
        tasks = [
            loop.run_in_executor(
                None, functools.partial(self.index.upsert, vectors=batch)
            )
            for batch in batches
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        for result in results:
            if isinstance(result, Exception):
                log.error(f"Error in async insert batch: {result}")
                raise result

        log.info(
            f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
        )

    async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
        """Async version of upsert using asyncio and run_in_executor for improved performance."""
        if not items:
            log.warning("No items to upsert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Create batches
        batches = [
            points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
        ]

        loop = asyncio.get_event_loop()
        tasks = [
            loop.run_in_executor(
                None, functools.partial(self.index.upsert, vectors=batch)
            )
            for batch in batches
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        for result in results:
            if isinstance(result, Exception):
                log.error(f"Error in async upsert batch: {result}")
                raise result

        log.info(
            f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
        )

    def streaming_upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Perform a streaming upsert over gRPC for performance testing."""
        if not items:
            log.warning("No items to upsert via streaming")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Open a streaming upsert channel
        # NOTE: this relies on a streaming upsert API being exposed by the installed
        # Pinecone gRPC client; use insert()/upsert() if it is not available.
        stream = self.index.streaming_upsert()
        try:
            for point in points:
                # Send each point over the stream
                stream.send(point)
            # Close the stream to finalize
            stream.close()
            log.info(
                f"Successfully streamed upsert of {len(points)} vectors into '{collection_name_with_prefix}'"
            )
        except Exception as e:
            log.error(f"Error during streaming upsert: {e}")
            raise

    def search(
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
    ) -> Optional[SearchResult]:
        """Search for similar vectors in a collection."""
        if not vectors or not vectors[0]:
            log.warning("No vectors provided for search")
            return None

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Search using the first vector (assuming this is the intended behavior)
            query_vector = vectors[0]

            # Perform the search
            query_response = self.index.query(
                vector=query_vector,
                top_k=limit,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )

            if not query_response.matches:
                # Return empty result if no matches
                return SearchResult(
                    ids=[[]],
                    documents=[[]],
                    metadatas=[[]],
                    distances=[[]],
                )

            # Convert to GetResult format
            get_result = self._result_to_get_result(query_response.matches)

            # Calculate normalized distances based on metric
            distances = [
                [
                    self._normalize_distance(match.score)
                    for match in query_response.matches
                ]
            ]

            return SearchResult(
                ids=get_result.ids,
                documents=get_result.documents,
                metadatas=get_result.metadatas,
                distances=distances,
            )
        except Exception as e:
            log.error(f"Error searching in '{collection_name_with_prefix}': {e}")
            return None

    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Query vectors by metadata filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Create a zero vector for the dimension as Pinecone requires a vector
            zero_vector = [0.0] * self.dimension

            # Combine user filter with collection_name
            pinecone_filter = {"collection_name": collection_name_with_prefix}
            if filter:
                pinecone_filter.update(filter)

            # Perform metadata-only query
            query_response = self.index.query(
                vector=zero_vector,
                filter=pinecone_filter,
                top_k=limit,
                include_metadata=True,
            )
            return self._result_to_get_result(query_response.matches)
        except Exception as e:
            log.error(f"Error querying collection '{collection_name}': {e}")
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Get all vectors in a collection."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            # Use a zero vector for fetching all entries
            zero_vector = [0.0] * self.dimension

            # Add filter to only get vectors for this collection
            query_response = self.index.query(
                vector=zero_vector,
                top_k=NO_LIMIT,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )
            return self._result_to_get_result(query_response.matches)
        except Exception as e:
            log.error(f"Error getting collection '{collection_name}': {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        """Delete vectors by IDs or filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            if ids:
                # Delete by IDs (in batches for large deletions)
                for i in range(0, len(ids), BATCH_SIZE):
                    batch_ids = ids[i : i + BATCH_SIZE]
                    # Note: when deleting by ID, we can't filter by collection_name.
                    # This is a limitation of Pinecone - be careful with ID uniqueness.
                    self.index.delete(ids=batch_ids)
                    log.debug(
                        f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
                    )
                log.info(
                    f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
                )
            elif filter:
                # Combine user filter with collection_name
                pinecone_filter = {"collection_name": collection_name_with_prefix}
                if filter:
                    pinecone_filter.update(filter)
                # Delete by metadata filter
                self.index.delete(filter=pinecone_filter)
                log.info(
                    f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'"
                )
            else:
                log.warning("No ids or filter provided for delete operation")
        except Exception as e:
            log.error(f"Error deleting from collection '{collection_name}': {e}")
            raise

    def reset(self) -> None:
        """Reset the database by deleting all vectors from the index."""
        try:
            self.index.delete(delete_all=True)
            log.info("All vectors successfully deleted from the index.")
        except Exception as e:
            log.error(f"Failed to reset Pinecone index: {e}")
            raise

    def close(self):
        """Shut down the gRPC channel and thread pool."""
        try:
            self.client.close()
            log.info("Pinecone gRPC channel closed.")
        except Exception as e:
            log.warning(f"Failed to close Pinecone gRPC channel: {e}")
        self._executor.shutdown(wait=True)

    def __enter__(self):
        """Enter context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context manager, ensuring resources are cleaned up."""
        self.close()
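

# Illustrative usage sketch, not part of the module's public surface. It assumes the
# PINECONE_* configuration values are set and that vectors match PINECONE_DIMENSION;
# the collection name "example-docs" and the item payloads are hypothetical. Items are
# written as plain dicts carrying the keys this module reads from VectorItem
# ("id", "vector", "text", "metadata").
if __name__ == "__main__":
    with PineconeClient() as client:
        # Upsert a single example vector into a prefixed collection
        client.upsert(
            "example-docs",
            [
                {
                    "id": "doc-1",
                    "vector": [0.1] * client.dimension,
                    "text": "hello world",
                    "metadata": {"source": "example"},
                }
            ],
        )
        # Query the same collection with one query vector and print the matches
        result = client.search("example-docs", [[0.1] * client.dimension], limit=5)
        if result:
            print(result.ids, result.distances)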