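"""pinecone.py

Pinecone-backed vector database client for Open WebUI. Implements the
VectorDBBase interface on top of the Pinecone gRPC client, tagging every
vector with a `collection_name` metadata field and batching inserts/upserts
across a persistent thread pool (see BATCH_SIZE below). Async and streaming
upsert variants are included for performance experiments.
"""
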
from typing import Optional, List, Dict, Any, Union
import logging
import time  # for measuring elapsed time
import asyncio  # for async upserts
import functools  # for partial binding in async tasks
import concurrent.futures  # for parallel batch upserts

from pinecone import ServerlessSpec
from pinecone.grpc import PineconeGRPC  # use gRPC client for faster upserts

from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    PINECONE_API_KEY,
    PINECONE_ENVIRONMENT,
    PINECONE_INDEX_NAME,
    PINECONE_DIMENSION,
    PINECONE_METRIC,
    PINECONE_CLOUD,
)
from open_webui.env import SRC_LOG_LEVELS

NO_LIMIT = 10000  # Reasonable limit to avoid overwhelming the system
BATCH_SIZE = 100  # Recommended batch size for Pinecone operations

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

class PineconeClient(VectorDBBase):
    def __init__(self):
        self.collection_prefix = "open-webui"

        # Validate required configuration
        self._validate_config()

        # Store configuration values
        self.api_key = PINECONE_API_KEY
        self.environment = PINECONE_ENVIRONMENT
        self.index_name = PINECONE_INDEX_NAME
        self.dimension = PINECONE_DIMENSION
        self.metric = PINECONE_METRIC
        self.cloud = PINECONE_CLOUD

        # Initialize Pinecone gRPC client for improved performance
        self.client = PineconeGRPC(
            api_key=self.api_key, environment=self.environment, cloud=self.cloud
        )

        # Persistent executor for batch operations
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)

        # Create index if it doesn't exist
        self._initialize_index()

    def _validate_config(self) -> None:
        """Validate that all required configuration variables are set."""
        missing_vars = []
        if not PINECONE_API_KEY:
            missing_vars.append("PINECONE_API_KEY")
        if not PINECONE_ENVIRONMENT:
            missing_vars.append("PINECONE_ENVIRONMENT")
        if not PINECONE_INDEX_NAME:
            missing_vars.append("PINECONE_INDEX_NAME")
        if not PINECONE_DIMENSION:
            missing_vars.append("PINECONE_DIMENSION")
        if not PINECONE_CLOUD:
            missing_vars.append("PINECONE_CLOUD")
        if missing_vars:
            raise ValueError(
                f"Required configuration missing: {', '.join(missing_vars)}"
            )

    def _initialize_index(self) -> None:
        """Initialize the Pinecone index."""
        try:
            # Check if index exists
            if self.index_name not in self.client.list_indexes().names():
                log.info(f"Creating Pinecone index '{self.index_name}'...")
                self.client.create_index(
                    name=self.index_name,
                    dimension=self.dimension,
                    metric=self.metric,
                    spec=ServerlessSpec(cloud=self.cloud, region=self.environment),
                )
                log.info(f"Successfully created Pinecone index '{self.index_name}'")
            else:
                log.info(f"Using existing Pinecone index '{self.index_name}'")

            # Connect to the index
            self.index = self.client.Index(self.index_name)
        except Exception as e:
            log.error(f"Failed to initialize Pinecone index: {e}")
            raise RuntimeError(f"Failed to initialize Pinecone index: {e}")

    def _create_points(
        self, items: List[VectorItem], collection_name_with_prefix: str
    ) -> List[Dict[str, Any]]:
        """Convert VectorItem objects to Pinecone point format."""
        points = []
        for item in items:
            # Start with any existing metadata or an empty dict
            metadata = item.get("metadata", {}).copy() if item.get("metadata") else {}

            # Add text to metadata if available
            if "text" in item:
                metadata["text"] = item["text"]

            # Always add collection_name to metadata for filtering
            metadata["collection_name"] = collection_name_with_prefix

            point = {
                "id": item["id"],
                "values": item["vector"],
                "metadata": metadata,
            }
            points.append(point)
        return points

    def _get_collection_name_with_prefix(self, collection_name: str) -> str:
        """Get the collection name with prefix."""
        return f"{self.collection_prefix}_{collection_name}"

    def _normalize_distance(self, score: float) -> float:
        """Normalize distance score based on the metric used."""
        if self.metric.lower() == "cosine":
            # Cosine similarity ranges from -1 to 1, normalize to 0 to 1
            return (score + 1.0) / 2.0
        elif self.metric.lower() in ["euclidean", "dotproduct"]:
            # These are already suitable for ranking (smaller is better for Euclidean)
            return score
        else:
            # For other metrics, use as is
            return score

    def _result_to_get_result(self, matches: list) -> GetResult:
        """Convert Pinecone matches to GetResult format."""
        ids = []
        documents = []
        metadatas = []

        for match in matches:
            metadata = match.get("metadata", {})
            ids.append(match["id"])
            documents.append(metadata.get("text", ""))
            metadatas.append(metadata)

        return GetResult(
            **{
                "ids": [ids],
                "documents": [documents],
                "metadatas": [metadatas],
            }
        )

    def has_collection(self, collection_name: str) -> bool:
        """Check if a collection exists by searching for at least one item."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            # Search for at least 1 item with this collection name in metadata
            response = self.index.query(
                vector=[0.0] * self.dimension,  # dummy vector
                top_k=1,
                filter={"collection_name": collection_name_with_prefix},
                include_metadata=False,
            )
            return len(response.matches) > 0
        except Exception as e:
            log.exception(
                f"Error checking collection '{collection_name_with_prefix}': {e}"
            )
            return False

    def delete_collection(self, collection_name: str) -> None:
        """Delete a collection by removing all vectors with the collection name in metadata."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            self.index.delete(filter={"collection_name": collection_name_with_prefix})
            log.info(
                f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)."
            )
        except Exception as e:
            log.warning(
                f"Failed to delete collection '{collection_name_with_prefix}': {e}"
            )
            raise

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert vectors into a collection."""
        if not items:
            log.warning("No items to insert")
            return

        start_time = time.time()
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Parallelize batch inserts for performance
        executor = self._executor
        futures = []
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            futures.append(executor.submit(self.index.upsert, vectors=batch))
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                log.error(f"Error inserting batch: {e}")
                raise

        elapsed = time.time() - start_time
        log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
        log.info(
            f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
        )

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Upsert (insert or update) vectors into a collection."""
        if not items:
            log.warning("No items to upsert")
            return

        start_time = time.time()
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Parallelize batch upserts for performance
        executor = self._executor
        futures = []
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            futures.append(executor.submit(self.index.upsert, vectors=batch))
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                log.error(f"Error upserting batch: {e}")
                raise

        elapsed = time.time() - start_time
        log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
        log.info(
            f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
        )

    async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
        """Async version of insert using asyncio and run_in_executor for improved performance."""
        if not items:
            log.warning("No items to insert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Create batches
        batches = [
            points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
        ]

        loop = asyncio.get_running_loop()
        tasks = [
            loop.run_in_executor(
                None,
                functools.partial(self.index.upsert, vectors=batch),
            )
            for batch in batches
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, Exception):
                log.error(f"Error in async insert batch: {result}")
                raise result
        log.info(
            f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
        )

    async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
        """Async version of upsert using asyncio and run_in_executor for improved performance."""
        if not items:
            log.warning("No items to upsert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Create batches
        batches = [
            points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
        ]

        loop = asyncio.get_running_loop()
        tasks = [
            loop.run_in_executor(
                None,
                functools.partial(self.index.upsert, vectors=batch),
            )
            for batch in batches
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, Exception):
                log.error(f"Error in async upsert batch: {result}")
                raise result
        log.info(
            f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
        )

    def streaming_upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Perform a streaming upsert over gRPC for performance testing."""
        if not items:
            log.warning("No items to upsert via streaming")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Open a streaming upsert channel
        stream = self.index.streaming_upsert()
        try:
            for point in points:
                # Send each point over the stream
                stream.send(point)
            # Close the stream to finalize
            stream.close()
            log.info(
                f"Successfully streamed upsert of {len(points)} vectors into '{collection_name_with_prefix}'"
            )
        except Exception as e:
            log.error(f"Error during streaming upsert: {e}")
            raise

    def search(
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
    ) -> Optional[SearchResult]:
        """Search for similar vectors in a collection."""
        if not vectors or not vectors[0]:
            log.warning("No vectors provided for search")
            return None

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Search using the first vector (assuming this is the intended behavior)
            query_vector = vectors[0]

            # Perform the search
            query_response = self.index.query(
                vector=query_vector,
                top_k=limit,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )

            if not query_response.matches:
                # Return empty result if no matches
                return SearchResult(
                    ids=[[]],
                    documents=[[]],
                    metadatas=[[]],
                    distances=[[]],
                )

            # Convert to GetResult format
            get_result = self._result_to_get_result(query_response.matches)

            # Calculate normalized distances based on metric
            distances = [
                [
                    self._normalize_distance(match.score)
                    for match in query_response.matches
                ]
            ]

            return SearchResult(
                ids=get_result.ids,
                documents=get_result.documents,
                metadatas=get_result.metadatas,
                distances=distances,
            )
        except Exception as e:
            log.error(f"Error searching in '{collection_name_with_prefix}': {e}")
            return None

    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Query vectors by metadata filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Create a zero vector for the dimension as Pinecone requires a vector
            zero_vector = [0.0] * self.dimension

            # Combine user filter with collection_name
            pinecone_filter = {"collection_name": collection_name_with_prefix}
            if filter:
                pinecone_filter.update(filter)

            # Perform metadata-only query
            query_response = self.index.query(
                vector=zero_vector,
                filter=pinecone_filter,
                top_k=limit,
                include_metadata=True,
            )
            return self._result_to_get_result(query_response.matches)
        except Exception as e:
            log.error(f"Error querying collection '{collection_name}': {e}")
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Get all vectors in a collection."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            # Use a zero vector for fetching all entries
            zero_vector = [0.0] * self.dimension

            # Add filter to only get vectors for this collection
            query_response = self.index.query(
                vector=zero_vector,
                top_k=NO_LIMIT,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )
            return self._result_to_get_result(query_response.matches)
        except Exception as e:
            log.error(f"Error getting collection '{collection_name}': {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        """Delete vectors by IDs or filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            if ids:
                # Delete by IDs (in batches for large deletions)
                for i in range(0, len(ids), BATCH_SIZE):
                    batch_ids = ids[i : i + BATCH_SIZE]
                    # Note: When deleting by ID, we can't filter by collection_name
                    # This is a limitation of Pinecone - be careful with ID uniqueness
                    self.index.delete(ids=batch_ids)
                    log.debug(
                        f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
                    )
                log.info(
                    f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
                )
            elif filter:
                # Combine user filter with collection_name
                pinecone_filter = {"collection_name": collection_name_with_prefix}
                if filter:
                    pinecone_filter.update(filter)
                # Delete by metadata filter
                self.index.delete(filter=pinecone_filter)
                log.info(
                    f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'"
                )
            else:
                log.warning("No ids or filter provided for delete operation")
        except Exception as e:
            log.error(f"Error deleting from collection '{collection_name}': {e}")
            raise

    def reset(self) -> None:
        """Reset the database by deleting all collections."""
        try:
            self.index.delete(delete_all=True)
            log.info("All vectors successfully deleted from the index.")
        except Exception as e:
            log.error(f"Failed to reset Pinecone index: {e}")
            raise

    def close(self):
        """Shut down the thread pool."""
        self._executor.shutdown(wait=True)
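

# Hypothetical usage sketch (illustrative only, not part of the module API).
# Assumes the PINECONE_* values imported from open_webui.config are populated.
# The collection name and item below are made-up examples; items are passed as
# plain dicts, matching the dict-style access used in _create_points above.
if __name__ == "__main__":
    client = PineconeClient()

    # Upsert a single example vector into a hypothetical collection
    client.upsert(
        "example-collection",
        [
            {
                "id": "doc-1",
                "vector": [0.0] * client.dimension,  # placeholder embedding
                "text": "hello world",
                "metadata": {"source": "example"},
            }
        ],
    )

    # Search the same collection with a placeholder query vector
    result = client.search(
        "example-collection", [[0.0] * client.dimension], limit=5
    )
    print(result)

    client.close()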