pinecone.py

from typing import Optional, List, Dict, Any, Union
import logging

from pinecone import Pinecone, ServerlessSpec

from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
    PINECONE_API_KEY,
    PINECONE_ENVIRONMENT,
    PINECONE_INDEX_NAME,
    PINECONE_DIMENSION,
    PINECONE_METRIC,
    PINECONE_CLOUD,
)
from open_webui.env import SRC_LOG_LEVELS

NO_LIMIT = 10000  # Reasonable limit to avoid overwhelming the system
BATCH_SIZE = 100  # Recommended batch size for Pinecone operations

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class PineconeClient:
    def __init__(self):
        self.collection_prefix = "open-webui"

        # Validate required configuration
        self._validate_config()

        # Store configuration values
        self.api_key = PINECONE_API_KEY
        self.environment = PINECONE_ENVIRONMENT
        self.index_name = PINECONE_INDEX_NAME
        self.dimension = PINECONE_DIMENSION
        self.metric = PINECONE_METRIC
        self.cloud = PINECONE_CLOUD

        # Initialize Pinecone client
        self.client = Pinecone(api_key=self.api_key)

        # Create index if it doesn't exist
        self._initialize_index()

    def _validate_config(self) -> None:
        """Validate that all required configuration variables are set."""
        missing_vars = []
        if not PINECONE_API_KEY:
            missing_vars.append("PINECONE_API_KEY")
        if not PINECONE_ENVIRONMENT:
            missing_vars.append("PINECONE_ENVIRONMENT")
        if not PINECONE_INDEX_NAME:
            missing_vars.append("PINECONE_INDEX_NAME")
        if not PINECONE_DIMENSION:
            missing_vars.append("PINECONE_DIMENSION")
        if not PINECONE_CLOUD:
            missing_vars.append("PINECONE_CLOUD")
        if missing_vars:
            raise ValueError(
                f"Required configuration missing: {', '.join(missing_vars)}"
            )

    def _initialize_index(self) -> None:
        """Initialize the Pinecone index."""
        try:
            # Check if index exists
            if self.index_name not in self.client.list_indexes().names():
                log.info(f"Creating Pinecone index '{self.index_name}'...")
                self.client.create_index(
                    name=self.index_name,
                    dimension=self.dimension,
                    metric=self.metric,
                    spec=ServerlessSpec(cloud=self.cloud, region=self.environment),
                )
                log.info(f"Successfully created Pinecone index '{self.index_name}'")
            else:
                log.info(f"Using existing Pinecone index '{self.index_name}'")

            # Connect to the index
            self.index = self.client.Index(self.index_name)
        except Exception as e:
            log.error(f"Failed to initialize Pinecone index: {e}")
            raise RuntimeError(f"Failed to initialize Pinecone index: {e}")

    def _create_points(
        self, items: List[VectorItem], collection_name_with_prefix: str
    ) -> List[Dict[str, Any]]:
        """Convert VectorItem objects to Pinecone point format."""
        points = []
        for item in items:
            # Start with any existing metadata or an empty dict
            metadata = item.get("metadata", {}).copy() if item.get("metadata") else {}

            # Add text to metadata if available
            if "text" in item:
                metadata["text"] = item["text"]

            # Always add collection_name to metadata for filtering
            metadata["collection_name"] = collection_name_with_prefix

            point = {
                "id": item["id"],
                "values": item["vector"],
                "metadata": metadata,
            }
            points.append(point)
        return points
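
    # Shape of the point dicts produced above (values are illustrative only):
    #   {
    #       "id": "doc-1",
    #       "values": [0.12, -0.03, ...],
    #       "metadata": {"text": "...", "collection_name": "open-webui_docs"},
    #   }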

    def _get_collection_name_with_prefix(self, collection_name: str) -> str:
        """Get the collection name with prefix."""
        return f"{self.collection_prefix}_{collection_name}"

    def _normalize_distance(self, score: float) -> float:
        """Normalize distance score based on the metric used."""
        if self.metric.lower() == "cosine":
            # Cosine similarity ranges from -1 to 1; normalize to 0 to 1
            return (score + 1.0) / 2.0
        elif self.metric.lower() in ["euclidean", "dotproduct"]:
            # These are already suitable for ranking (smaller is better for Euclidean)
            return score
        else:
            # For other metrics, use the score as-is
            return score
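
    # Worked example for the cosine branch: a raw score of 1.0 (identical
    # direction) maps to (1.0 + 1.0) / 2.0 = 1.0, a score of 0.0 (orthogonal
    # vectors) maps to 0.5, and a score of -1.0 (opposite direction) maps to 0.0.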

    def _result_to_get_result(self, matches: list) -> GetResult:
        """Convert Pinecone matches to GetResult format."""
        ids = []
        documents = []
        metadatas = []

        for match in matches:
            metadata = match.get("metadata", {})
            ids.append(match["id"])
            documents.append(metadata.get("text", ""))
            metadatas.append(metadata)

        return GetResult(
            **{
                "ids": [ids],
                "documents": [documents],
                "metadatas": [metadatas],
            }
        )
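
    # Note the extra nesting above: GetResult fields are lists of lists, one
    # inner list per query. This client always issues a single query, so each
    # field holds exactly one inner list.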

    def has_collection(self, collection_name: str) -> bool:
        """Check if a collection exists by searching for at least one item."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            # Search for at least 1 item with this collection name in metadata
            response = self.index.query(
                vector=[0.0] * self.dimension,  # dummy vector
                top_k=1,
                filter={"collection_name": collection_name_with_prefix},
                include_metadata=False,
            )
            return len(response.matches) > 0
        except Exception as e:
            log.exception(
                f"Error checking collection '{collection_name_with_prefix}': {e}"
            )
            return False

    def delete_collection(self, collection_name: str) -> None:
        """Delete a collection by removing all vectors with the collection name in metadata."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            self.index.delete(filter={"collection_name": collection_name_with_prefix})
            log.info(
                f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)."
            )
        except Exception as e:
            log.warning(
                f"Failed to delete collection '{collection_name_with_prefix}': {e}"
            )
            raise

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert vectors into a collection."""
        if not items:
            log.warning("No items to insert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Insert in batches for better performance and reliability
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            try:
                self.index.upsert(vectors=batch)
                log.debug(
                    f"Inserted batch of {len(batch)} vectors into '{collection_name_with_prefix}'"
                )
            except Exception as e:
                log.error(
                    f"Error inserting batch into '{collection_name_with_prefix}': {e}"
                )
                raise

        log.info(
            f"Successfully inserted {len(items)} vectors into '{collection_name_with_prefix}'"
        )
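
    # Batching example: with BATCH_SIZE = 100, 250 points are sent as three
    # upsert calls of 100, 100, and 50 vectors. Both insert() and upsert()
    # rely on Pinecone's upsert operation, so inserting an existing ID
    # silently overwrites that vector rather than raising an error.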

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Upsert (insert or update) vectors into a collection."""
        if not items:
            log.warning("No items to upsert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Upsert in batches
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            try:
                self.index.upsert(vectors=batch)
                log.debug(
                    f"Upserted batch of {len(batch)} vectors into '{collection_name_with_prefix}'"
                )
            except Exception as e:
                log.error(
                    f"Error upserting batch into '{collection_name_with_prefix}': {e}"
                )
                raise

        log.info(
            f"Successfully upserted {len(items)} vectors into '{collection_name_with_prefix}'"
        )

    def search(
        self,
        collection_name: str,
        vectors: List[List[Union[float, int]]],
        limit: Optional[int],
    ) -> Optional[SearchResult]:
        """Search for similar vectors in a collection."""
        if not vectors or not vectors[0]:
            log.warning("No vectors provided for search")
            return None

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Search using the first vector (assuming this is the intended behavior)
            query_vector = vectors[0]

            # Perform the search
            query_response = self.index.query(
                vector=query_vector,
                top_k=limit,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )

            if not query_response.matches:
                # Return empty result if no matches
                return SearchResult(
                    ids=[[]],
                    documents=[[]],
                    metadatas=[[]],
                    distances=[[]],
                )

            # Convert to GetResult format
            get_result = self._result_to_get_result(query_response.matches)

            # Calculate normalized distances based on metric
            distances = [
                [
                    self._normalize_distance(match.score)
                    for match in query_response.matches
                ]
            ]

            return SearchResult(
                ids=get_result.ids,
                documents=get_result.documents,
                metadatas=get_result.metadatas,
                distances=distances,
            )
        except Exception as e:
            log.error(f"Error searching in '{collection_name_with_prefix}': {e}")
            return None

    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Query vectors by metadata filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Create a zero vector for the dimension as Pinecone requires a vector
            zero_vector = [0.0] * self.dimension

            # Combine user filter with collection_name
            pinecone_filter = {"collection_name": collection_name_with_prefix}
            if filter:
                pinecone_filter.update(filter)

            # Perform metadata-only query
            query_response = self.index.query(
                vector=zero_vector,
                filter=pinecone_filter,
                top_k=limit,
                include_metadata=True,
            )
            return self._result_to_get_result(query_response.matches)
        except Exception as e:
            log.error(f"Error querying collection '{collection_name}': {e}")
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Get all vectors in a collection."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            # Use a zero vector for fetching all entries
            zero_vector = [0.0] * self.dimension

            # Add filter to only get vectors for this collection
            query_response = self.index.query(
                vector=zero_vector,
                top_k=NO_LIMIT,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )
            return self._result_to_get_result(query_response.matches)
        except Exception as e:
            log.error(f"Error getting collection '{collection_name}': {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        """Delete vectors by IDs or filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            if ids:
                # Delete by IDs (in batches for large deletions)
                for i in range(0, len(ids), BATCH_SIZE):
                    batch_ids = ids[i : i + BATCH_SIZE]
                    # Note: when deleting by ID, we can't also filter by
                    # collection_name; this is a Pinecone limitation, so be
                    # careful with ID uniqueness across collections
                    self.index.delete(ids=batch_ids)
                    log.debug(
                        f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
                    )
                log.info(
                    f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
                )
            elif filter:
                # Combine user filter with collection_name
                pinecone_filter = {"collection_name": collection_name_with_prefix}
                pinecone_filter.update(filter)

                # Delete by metadata filter
                self.index.delete(filter=pinecone_filter)
                log.info(
                    f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'"
                )
            else:
                log.warning("No ids or filter provided for delete operation")
        except Exception as e:
            log.error(f"Error deleting from collection '{collection_name}': {e}")
            raise

    def reset(self) -> None:
        """Reset the database by deleting all collections."""
        try:
            self.index.delete(delete_all=True)
            log.info("All vectors successfully deleted from the index.")
        except Exception as e:
            log.error(f"Failed to reset Pinecone index: {e}")
            raise
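

# Minimal usage sketch, not part of the client above. It assumes the
# PINECONE_* values in open_webui.config are set and the index is reachable;
# the names below ("demo", items, the vector values) are illustrative only.
if __name__ == "__main__":
    client = PineconeClient()

    # Upsert two items into a logical collection. All logical collections
    # share one physical index and are separated only by the metadata field
    # "collection_name" (here "open-webui_demo").
    items = [
        {"id": "a", "vector": [0.1] * client.dimension, "text": "hello", "metadata": {}},
        {"id": "b", "vector": [0.9] * client.dimension, "text": "world", "metadata": {}},
    ]
    client.upsert("demo", items)

    # search() takes a list of query vectors but only uses the first one;
    # each result field is nested one list deep (single query).
    result = client.search("demo", vectors=[[0.5] * client.dimension], limit=5)
    if result is not None:
        print(result.ids[0], result.distances[0])

    # Remove every vector tagged with this logical collection.
    client.delete_collection("demo")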