# pinecone.py
from typing import Optional, List, Dict, Any, Union
import logging

from pinecone import Pinecone, ServerlessSpec

from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    PINECONE_API_KEY,
    PINECONE_ENVIRONMENT,
    PINECONE_INDEX_NAME,
    PINECONE_DIMENSION,
    PINECONE_METRIC,
    PINECONE_CLOUD,
)
from open_webui.env import SRC_LOG_LEVELS

NO_LIMIT = 10000  # Reasonable limit to avoid overwhelming the system
BATCH_SIZE = 100  # Recommended batch size for Pinecone operations

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class PineconeClient(VectorDBBase):
    def __init__(self):
        self.collection_prefix = "open-webui"

        # Validate required configuration
        self._validate_config()

        # Store configuration values
        self.api_key = PINECONE_API_KEY
        self.environment = PINECONE_ENVIRONMENT
        self.index_name = PINECONE_INDEX_NAME
        self.dimension = PINECONE_DIMENSION
        self.metric = PINECONE_METRIC
        self.cloud = PINECONE_CLOUD

        # Initialize Pinecone client
        self.client = Pinecone(api_key=self.api_key)

        # Create index if it doesn't exist
        self._initialize_index()

    def _validate_config(self) -> None:
        """Validate that all required configuration variables are set."""
        missing_vars = []
        if not PINECONE_API_KEY:
            missing_vars.append("PINECONE_API_KEY")
        if not PINECONE_ENVIRONMENT:
            missing_vars.append("PINECONE_ENVIRONMENT")
        if not PINECONE_INDEX_NAME:
            missing_vars.append("PINECONE_INDEX_NAME")
        if not PINECONE_DIMENSION:
            missing_vars.append("PINECONE_DIMENSION")
        if not PINECONE_CLOUD:
            missing_vars.append("PINECONE_CLOUD")
        if missing_vars:
            raise ValueError(
                f"Required configuration missing: {', '.join(missing_vars)}"
            )

    def _initialize_index(self) -> None:
        """Initialize the Pinecone index."""
        try:
            # Check if the index exists
            if self.index_name not in self.client.list_indexes().names():
                log.info(f"Creating Pinecone index '{self.index_name}'...")
                self.client.create_index(
                    name=self.index_name,
                    dimension=self.dimension,
                    metric=self.metric,
                    spec=ServerlessSpec(cloud=self.cloud, region=self.environment),
                )
                log.info(f"Successfully created Pinecone index '{self.index_name}'")
            else:
                log.info(f"Using existing Pinecone index '{self.index_name}'")

            # Connect to the index
            self.index = self.client.Index(self.index_name)
        except Exception as e:
            log.error(f"Failed to initialize Pinecone index: {e}")
            raise RuntimeError(f"Failed to initialize Pinecone index: {e}")

    def _create_points(
        self, items: List[VectorItem], collection_name_with_prefix: str
    ) -> List[Dict[str, Any]]:
        """Convert VectorItem objects to Pinecone point format."""
        points = []
        for item in items:
            # Start with any existing metadata or an empty dict
            metadata = item.get("metadata", {}).copy() if item.get("metadata") else {}

            # Add text to metadata if available
            if "text" in item:
                metadata["text"] = item["text"]

            # Always add collection_name to metadata for filtering
            metadata["collection_name"] = collection_name_with_prefix

            point = {
                "id": item["id"],
                "values": item["vector"],
                "metadata": metadata,
            }
            points.append(point)
        return points
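
    # Illustrative shape of one point produced above; the id, values, and
    # collection name are hypothetical, not from any real data:
    #
    #     {
    #         "id": "doc-1",
    #         "values": [0.12, -0.07, ...],
    #         "metadata": {
    #             "text": "chunk contents",
    #             "collection_name": "open-webui_my-collection",
    #         },
    #     }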

    def _get_collection_name_with_prefix(self, collection_name: str) -> str:
        """Get the collection name with prefix."""
        return f"{self.collection_prefix}_{collection_name}"

    def _normalize_distance(self, score: float) -> float:
        """Normalize distance score based on the metric used."""
        if self.metric.lower() == "cosine":
            # Cosine similarity ranges from -1 to 1; normalize to 0 to 1
            return (score + 1.0) / 2.0
        elif self.metric.lower() in ["euclidean", "dotproduct"]:
            # These scores are already suitable for ranking
            # (smaller is better for Euclidean)
            return score
        else:
            # For other metrics, use the score as-is
            return score
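
    # Worked example for the cosine branch: a raw Pinecone score of 0.8
    # normalizes to (0.8 + 1.0) / 2.0 == 0.9, and the minimum score of -1.0
    # maps to 0.0, so normalized cosine scores always land in [0, 1].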

    def _result_to_get_result(self, matches: list) -> GetResult:
        """Convert Pinecone matches to GetResult format."""
        ids = []
        documents = []
        metadatas = []

        for match in matches:
            metadata = match.get("metadata", {})
            ids.append(match["id"])
            documents.append(metadata.get("text", ""))
            metadatas.append(metadata)

        return GetResult(
            ids=[ids],
            documents=[documents],
            metadatas=[metadatas],
        )
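
    # The nesting above mirrors the batched GetResult contract: one inner list
    # per query, e.g. ids == [["a", "b"]] for a single query with two matches.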

    def has_collection(self, collection_name: str) -> bool:
        """Check if a collection exists by searching for at least one item."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            # Search for at least one item with this collection name in metadata
            response = self.index.query(
                vector=[0.0] * self.dimension,  # dummy vector
                top_k=1,
                filter={"collection_name": collection_name_with_prefix},
                include_metadata=False,
            )
            return len(response.matches) > 0
        except Exception as e:
            log.exception(
                f"Error checking collection '{collection_name_with_prefix}': {e}"
            )
            return False

    def delete_collection(self, collection_name: str) -> None:
        """Delete a collection by removing all vectors with the collection name in metadata."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            self.index.delete(filter={"collection_name": collection_name_with_prefix})
            log.info(
                f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)."
            )
        except Exception as e:
            log.warning(
                f"Failed to delete collection '{collection_name_with_prefix}': {e}"
            )
            raise

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert vectors into a collection."""
        if not items:
            log.warning("No items to insert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Insert in batches for better performance and reliability.
        # Pinecone only exposes upsert for writes, so inserting an existing ID
        # overwrites that vector.
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            try:
                self.index.upsert(vectors=batch)
                log.debug(
                    f"Inserted batch of {len(batch)} vectors into '{collection_name_with_prefix}'"
                )
            except Exception as e:
                log.error(
                    f"Error inserting batch into '{collection_name_with_prefix}': {e}"
                )
                raise

        log.info(
            f"Successfully inserted {len(items)} vectors into '{collection_name_with_prefix}'"
        )

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Upsert (insert or update) vectors into a collection."""
        if not items:
            log.warning("No items to upsert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Upsert in batches
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            try:
                self.index.upsert(vectors=batch)
                log.debug(
                    f"Upserted batch of {len(batch)} vectors into '{collection_name_with_prefix}'"
                )
            except Exception as e:
                log.error(
                    f"Error upserting batch into '{collection_name_with_prefix}': {e}"
                )
                raise

        log.info(
            f"Successfully upserted {len(items)} vectors into '{collection_name_with_prefix}'"
        )

    def search(
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
    ) -> Optional[SearchResult]:
        """Search for similar vectors in a collection."""
        if not vectors or not vectors[0]:
            log.warning("No vectors provided for search")
            return None

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Pinecone queries take a single vector, so only the first one is used
            query_vector = vectors[0]

            # Perform the search
            query_response = self.index.query(
                vector=query_vector,
                top_k=limit,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )

            if not query_response.matches:
                # Return an empty result if there are no matches
                return SearchResult(
                    ids=[[]],
                    documents=[[]],
                    metadatas=[[]],
                    distances=[[]],
                )

            # Convert to GetResult format
            get_result = self._result_to_get_result(query_response.matches)

            # Calculate normalized distances based on the metric
            distances = [
                [
                    self._normalize_distance(match.score)
                    for match in query_response.matches
                ]
            ]

            return SearchResult(
                ids=get_result.ids,
                documents=get_result.documents,
                metadatas=get_result.metadatas,
                distances=distances,
            )
        except Exception as e:
            log.error(f"Error searching in '{collection_name_with_prefix}': {e}")
            return None

    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Query vectors by metadata filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Pinecone requires a query vector, so use a zero vector for
            # metadata-only queries
            zero_vector = [0.0] * self.dimension

            # Combine the user filter with the collection_name filter
            pinecone_filter = {"collection_name": collection_name_with_prefix}
            if filter:
                pinecone_filter.update(filter)

            query_response = self.index.query(
                vector=zero_vector,
                filter=pinecone_filter,
                top_k=limit,
                include_metadata=True,
            )
            return self._result_to_get_result(query_response.matches)
        except Exception as e:
            log.error(f"Error querying collection '{collection_name}': {e}")
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Get all vectors in a collection."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            # Use a zero vector for fetching all entries,
            # filtered to this collection
            zero_vector = [0.0] * self.dimension
            query_response = self.index.query(
                vector=zero_vector,
                top_k=NO_LIMIT,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )
            return self._result_to_get_result(query_response.matches)
        except Exception as e:
            log.error(f"Error getting collection '{collection_name}': {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        """Delete vectors by IDs or by metadata filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            if ids:
                # Delete by IDs, in batches for large deletions.
                # Note: deleting by ID cannot be filtered by collection_name;
                # this is a Pinecone limitation, so IDs must be unique across
                # collections.
                for i in range(0, len(ids), BATCH_SIZE):
                    batch_ids = ids[i : i + BATCH_SIZE]
                    self.index.delete(ids=batch_ids)
                    log.debug(
                        f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
                    )
                log.info(
                    f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
                )
            elif filter:
                # Combine the user filter with the collection_name filter
                pinecone_filter = {"collection_name": collection_name_with_prefix}
                pinecone_filter.update(filter)

                # Delete by metadata filter
                self.index.delete(filter=pinecone_filter)
                log.info(
                    f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'"
                )
            else:
                log.warning("No ids or filter provided for delete operation")
        except Exception as e:
            log.error(f"Error deleting from collection '{collection_name}': {e}")
            raise

    def reset(self) -> None:
        """Reset the database by deleting all vectors in the index."""
        try:
            self.index.delete(delete_all=True)
            log.info("All vectors successfully deleted from the index.")
        except Exception as e:
            log.error(f"Failed to reset Pinecone index: {e}")
            raise
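

# A minimal usage sketch, assuming the PINECONE_* config values are set and a
# 3-dimensional index purely for illustration; the "demo" collection, ids, and
# vectors below are hypothetical and not part of the module.
if __name__ == "__main__":
    client = PineconeClient()
    client.upsert(
        "demo",
        [
            {"id": "a", "vector": [0.1, 0.2, 0.3], "text": "hello", "metadata": {}},
            {"id": "b", "vector": [0.3, 0.2, 0.1], "text": "world", "metadata": {}},
        ],
    )
    result = client.search("demo", vectors=[[0.1, 0.2, 0.3]], limit=2)
    if result:
        print(result.ids, result.distances)
    client.delete_collection("demo")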