pinecone.py

from typing import Optional, List, Dict, Any, Union
import logging
import time  # for measuring elapsed time
from pinecone import Pinecone, ServerlessSpec
import asyncio  # for async upserts
import functools  # for partial binding in async tasks
import concurrent.futures  # for parallel batch upserts

from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    PINECONE_API_KEY,
    PINECONE_ENVIRONMENT,
    PINECONE_INDEX_NAME,
    PINECONE_DIMENSION,
    PINECONE_METRIC,
    PINECONE_CLOUD,
)
from open_webui.env import SRC_LOG_LEVELS

NO_LIMIT = 10000  # Reasonable limit to avoid overwhelming the system
BATCH_SIZE = 100  # Recommended batch size for Pinecone operations

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class PineconeClient(VectorDBBase):
    def __init__(self):
        self.collection_prefix = "open-webui"

        # Validate required configuration
        self._validate_config()

        # Store configuration values
        self.api_key = PINECONE_API_KEY
        self.environment = PINECONE_ENVIRONMENT
        self.index_name = PINECONE_INDEX_NAME
        self.dimension = PINECONE_DIMENSION
        self.metric = PINECONE_METRIC
        self.cloud = PINECONE_CLOUD

        # Initialize the Pinecone client
        self.client = Pinecone(api_key=self.api_key)

        # Persistent executor for parallel batch operations
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)

        # Create the index if it doesn't exist
        self._initialize_index()

    def _validate_config(self) -> None:
        """Validate that all required configuration variables are set."""
        missing_vars = []
        if not PINECONE_API_KEY:
            missing_vars.append("PINECONE_API_KEY")
        if not PINECONE_ENVIRONMENT:
            missing_vars.append("PINECONE_ENVIRONMENT")
        if not PINECONE_INDEX_NAME:
            missing_vars.append("PINECONE_INDEX_NAME")
        if not PINECONE_DIMENSION:
            missing_vars.append("PINECONE_DIMENSION")
        if not PINECONE_CLOUD:
            missing_vars.append("PINECONE_CLOUD")

        if missing_vars:
            raise ValueError(
                f"Required configuration missing: {', '.join(missing_vars)}"
            )

    def _initialize_index(self) -> None:
        """Initialize the Pinecone index."""
        try:
            # Check if the index exists
            if self.index_name not in self.client.list_indexes().names():
                log.info(f"Creating Pinecone index '{self.index_name}'...")
                self.client.create_index(
                    name=self.index_name,
                    dimension=self.dimension,
                    metric=self.metric,
                    spec=ServerlessSpec(cloud=self.cloud, region=self.environment),
                )
                log.info(f"Successfully created Pinecone index '{self.index_name}'")
            else:
                log.info(f"Using existing Pinecone index '{self.index_name}'")

            # Connect to the index
            self.index = self.client.Index(self.index_name)
        except Exception as e:
            log.error(f"Failed to initialize Pinecone index: {e}")
            raise RuntimeError(f"Failed to initialize Pinecone index: {e}")

    def _create_points(
        self, items: List[VectorItem], collection_name_with_prefix: str
    ) -> List[Dict[str, Any]]:
        """Convert VectorItem objects to Pinecone point format."""
        points = []
        for item in items:
            # Start with any existing metadata or an empty dict
            metadata = item.get("metadata", {}).copy() if item.get("metadata") else {}

            # Add text to metadata if available
            if "text" in item:
                metadata["text"] = item["text"]

            # Always add collection_name to metadata for filtering
            metadata["collection_name"] = collection_name_with_prefix

            point = {
                "id": item["id"],
                "values": item["vector"],
                "metadata": metadata,
            }
            points.append(point)
        return points

    def _get_collection_name_with_prefix(self, collection_name: str) -> str:
        """Get the collection name with prefix."""
        return f"{self.collection_prefix}_{collection_name}"

    def _normalize_distance(self, score: float) -> float:
        """Normalize distance score based on the metric used."""
        if self.metric.lower() == "cosine":
            # Cosine similarity ranges from -1 to 1; normalize to 0 to 1
            return (score + 1.0) / 2.0
        elif self.metric.lower() in ["euclidean", "dotproduct"]:
            # These are already suitable for ranking (smaller is better for Euclidean)
            return score
        else:
            # For other metrics, use the score as is
            return score

    def _result_to_get_result(self, matches: list) -> GetResult:
        """Convert Pinecone matches to GetResult format."""
        ids = []
        documents = []
        metadatas = []

        for match in matches:
            metadata = getattr(match, "metadata", {}) or {}
            ids.append(match.id if hasattr(match, "id") else match["id"])
            documents.append(metadata.get("text", ""))
            metadatas.append(metadata)

        return GetResult(
            **{
                "ids": [ids],
                "documents": [documents],
                "metadatas": [metadatas],
            }
        )

    def has_collection(self, collection_name: str) -> bool:
        """Check if a collection exists by searching for at least one item."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        try:
            # Search for at least 1 item with this collection name in metadata
            response = self.index.query(
                vector=[0.0] * self.dimension,  # dummy vector
                top_k=1,
                filter={"collection_name": collection_name_with_prefix},
                include_metadata=False,
            )
            matches = getattr(response, "matches", []) or []
            return len(matches) > 0
        except Exception as e:
            log.exception(
                f"Error checking collection '{collection_name_with_prefix}': {e}"
            )
            return False

    def delete_collection(self, collection_name: str) -> None:
        """Delete a collection by removing all vectors with the collection name in metadata."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            self.index.delete(filter={"collection_name": collection_name_with_prefix})
            log.info(
                f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)."
            )
        except Exception as e:
            log.warning(
                f"Failed to delete collection '{collection_name_with_prefix}': {e}"
            )
            raise

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert vectors into a collection."""
        if not items:
            log.warning("No items to insert")
            return

        start_time = time.time()
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Parallelize batch inserts for performance
        executor = self._executor
        futures = []
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            futures.append(executor.submit(self.index.upsert, vectors=batch))
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                log.error(f"Error inserting batch: {e}")
                raise

        elapsed = time.time() - start_time
        log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
        log.info(
            f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
        )

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Upsert (insert or update) vectors into a collection."""
        if not items:
            log.warning("No items to upsert")
            return

        start_time = time.time()
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Parallelize batch upserts for performance
        executor = self._executor
        futures = []
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            futures.append(executor.submit(self.index.upsert, vectors=batch))
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                log.error(f"Error upserting batch: {e}")
                raise

        elapsed = time.time() - start_time
        log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
        log.info(
            f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
        )

    async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
        """Async version of insert using asyncio and run_in_executor for improved performance."""
        if not items:
            log.warning("No items to insert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Create batches
        batches = [
            points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
        ]

        # get_running_loop() is preferred over get_event_loop() inside a coroutine
        loop = asyncio.get_running_loop()
        tasks = [
            loop.run_in_executor(
                None, functools.partial(self.index.upsert, vectors=batch)
            )
            for batch in batches
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, Exception):
                log.error(f"Error in async insert batch: {result}")
                raise result

        log.info(
            f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
        )

    async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
        """Async version of upsert using asyncio and run_in_executor for improved performance."""
        if not items:
            log.warning("No items to upsert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Create batches
        batches = [
            points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
        ]

        # get_running_loop() is preferred over get_event_loop() inside a coroutine
        loop = asyncio.get_running_loop()
        tasks = [
            loop.run_in_executor(
                None, functools.partial(self.index.upsert, vectors=batch)
            )
            for batch in batches
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, Exception):
                log.error(f"Error in async upsert batch: {result}")
                raise result

        log.info(
            f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
        )

    def search(
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
    ) -> Optional[SearchResult]:
        """Search for similar vectors in a collection."""
        if not vectors or not vectors[0]:
            log.warning("No vectors provided for search")
            return None

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Search using the first vector (assuming this is the intended behavior)
            query_vector = vectors[0]

            # Perform the search
            query_response = self.index.query(
                vector=query_vector,
                top_k=limit,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )
            matches = getattr(query_response, "matches", []) or []

            if not matches:
                # Return empty result if no matches
                return SearchResult(
                    ids=[[]],
                    documents=[[]],
                    metadatas=[[]],
                    distances=[[]],
                )

            # Convert to GetResult format
            get_result = self._result_to_get_result(matches)

            # Calculate normalized distances based on metric
            distances = [
                [
                    self._normalize_distance(getattr(match, "score", 0.0))
                    for match in matches
                ]
            ]

            return SearchResult(
                ids=get_result.ids,
                documents=get_result.documents,
                metadatas=get_result.metadatas,
                distances=distances,
            )
        except Exception as e:
            log.error(f"Error searching in '{collection_name_with_prefix}': {e}")
            return None

    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Query vectors by metadata filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Create a zero vector for the dimension as Pinecone requires a vector
            zero_vector = [0.0] * self.dimension

            # Combine user filter with collection_name
            pinecone_filter = {"collection_name": collection_name_with_prefix}
            if filter:
                pinecone_filter.update(filter)

            # Perform metadata-only query
            query_response = self.index.query(
                vector=zero_vector,
                filter=pinecone_filter,
                top_k=limit,
                include_metadata=True,
            )
            matches = getattr(query_response, "matches", []) or []
            return self._result_to_get_result(matches)
        except Exception as e:
            log.error(f"Error querying collection '{collection_name}': {e}")
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Get all vectors in a collection."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        try:
            # Use a zero vector for fetching all entries
            zero_vector = [0.0] * self.dimension

            # Add filter to only get vectors for this collection
            query_response = self.index.query(
                vector=zero_vector,
                top_k=NO_LIMIT,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )
            matches = getattr(query_response, "matches", []) or []
            return self._result_to_get_result(matches)
        except Exception as e:
            log.error(f"Error getting collection '{collection_name}': {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        """Delete vectors by IDs or filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            if ids:
                # Delete by IDs (in batches for large deletions)
                for i in range(0, len(ids), BATCH_SIZE):
                    batch_ids = ids[i : i + BATCH_SIZE]
                    # Note: When deleting by ID, we can't filter by collection_name.
                    # This is a limitation of Pinecone - be careful with ID uniqueness.
                    self.index.delete(ids=batch_ids)
                    log.debug(
                        f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
                    )
                log.info(
                    f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
                )
            elif filter:
                # Combine user filter with collection_name
                pinecone_filter = {"collection_name": collection_name_with_prefix}
                if filter:
                    pinecone_filter.update(filter)
                # Delete by metadata filter
                self.index.delete(filter=pinecone_filter)
                log.info(
                    f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'"
                )
            else:
                log.warning("No ids or filter provided for delete operation")
        except Exception as e:
            log.error(f"Error deleting from collection '{collection_name}': {e}")
            raise

    def reset(self) -> None:
        """Reset the database by deleting all collections."""
        try:
            self.index.delete(delete_all=True)
            log.info("All vectors successfully deleted from the index.")
        except Exception as e:
            log.error(f"Failed to reset Pinecone index: {e}")
            raise

    def close(self):
        """Shut down resources."""
        # The Pinecone client itself doesn't need explicit closing; just
        # release the thread pool used for parallel batch operations.
        try:
            self._executor.shutdown(wait=True)
        except Exception as e:
            log.warning(f"Failed to clean up Pinecone resources: {e}")

    def __enter__(self):
        """Enter context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context manager, ensuring resources are cleaned up."""
        self.close()
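

# --- Usage sketch -----------------------------------------------------------
# A minimal, hedged example of how this client might be exercised, assuming the
# PINECONE_* configuration values above are set and the vectors match the
# configured index dimension. The collection name "example-docs" and the item
# below are hypothetical placeholders for illustration, not part of the module.
if __name__ == "__main__":
    with PineconeClient() as client:
        dim = client.dimension
        client.upsert(
            "example-docs",
            [
                {
                    "id": "doc-1",
                    "text": "hello world",
                    "vector": [0.1] * dim,
                    "metadata": {"source": "example"},
                }
            ],
        )
        result = client.search("example-docs", [[0.1] * dim], limit=3)
        if result:
            print(result.ids, result.distances)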