pinecone.py

from typing import Optional, List, Dict, Any, Union

import asyncio
import logging
from datetime import datetime

from pinecone import Pinecone, ServerlessSpec

from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    PINECONE_API_KEY,
    PINECONE_ENVIRONMENT,
    PINECONE_INDEX_NAME,
    PINECONE_DIMENSION,
    PINECONE_METRIC,
    PINECONE_CLOUD,
    PINECONE_NAMESPACE,
)
from open_webui.env import SRC_LOG_LEVELS

NO_LIMIT = 10000  # Reasonable limit to avoid overwhelming the system
BATCH_SIZE = 100  # Recommended batch size for Pinecone operations

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def build_metadata(
    *,
    source: str,
    type_: str,
    user_id: str,
    chat_id: Optional[str] = None,
    filename: Optional[str] = None,
    text: Optional[str] = None,
    topic: Optional[str] = None,
    model: Optional[str] = None,
    vector_dim: Optional[int] = None,
    extra: Optional[Dict[str, Any]] = None,
    collection_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Helper for building consistent metadata.

    Required fields are always present; optional fields are added only
    when they are truthy, so the stored metadata stays compact.
    """
    metadata = {
        "source": source,
        "type": type_,
        "user_id": user_id,
        "timestamp": datetime.utcnow().isoformat() + "Z",
    }
    if chat_id:
        metadata["chat_id"] = chat_id
    if filename:
        metadata["filename"] = filename
    if text:
        metadata["text"] = text
    if topic:
        metadata["topic"] = topic
    if model:
        metadata["model"] = model
    if vector_dim:
        metadata["vector_dim"] = vector_dim
    if collection_name:
        metadata["collection_name"] = collection_name
    if extra:
        metadata.update(extra)
    return metadata
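

# Illustrative example (not executed): with the keyword-only API above,
#   build_metadata(source="chat", type_="upload", user_id="u-123", text="hello")
# yields something like:
#   {"source": "chat", "type": "upload", "user_id": "u-123",
#    "timestamp": "2025-01-01T12:00:00.000000Z", "text": "hello"}
# Optional fields (chat_id, filename, ...) appear only when truthy.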


class PineconeClient(VectorDBBase):
    def __init__(self):
        self.namespace = PINECONE_NAMESPACE

        # Validate required configuration
        self._validate_config()

        # Store configuration values
        self.api_key = PINECONE_API_KEY
        self.environment = PINECONE_ENVIRONMENT
        self.index_name = PINECONE_INDEX_NAME
        self.dimension = PINECONE_DIMENSION
        self.metric = PINECONE_METRIC
        self.cloud = PINECONE_CLOUD

        # Initialize Pinecone client
        self.client = Pinecone(api_key=self.api_key)

        # Create index if it doesn't exist
        self._initialize_index()

    def _validate_config(self) -> None:
        """Validate that all required configuration variables are set."""
        missing_vars = []
        if not PINECONE_API_KEY:
            missing_vars.append("PINECONE_API_KEY")
        if not PINECONE_ENVIRONMENT:
            missing_vars.append("PINECONE_ENVIRONMENT")
        if not PINECONE_INDEX_NAME:
            missing_vars.append("PINECONE_INDEX_NAME")
        if not PINECONE_DIMENSION:
            missing_vars.append("PINECONE_DIMENSION")
        if not PINECONE_CLOUD:
            missing_vars.append("PINECONE_CLOUD")
        if missing_vars:
            raise ValueError(
                f"Required configuration missing: {', '.join(missing_vars)}"
            )

    def _initialize_index(self) -> None:
        """Initialize the Pinecone index, creating it if it does not exist."""
        try:
            # Check if the index exists
            if self.index_name not in self.client.list_indexes().names():
                log.info(f"Creating Pinecone index '{self.index_name}'...")
                self.client.create_index(
                    name=self.index_name,
                    dimension=self.dimension,
                    metric=self.metric,
                    spec=ServerlessSpec(cloud=self.cloud, region=self.environment),
                )
                log.info(f"Successfully created Pinecone index '{self.index_name}'")
            else:
                log.info(f"Using existing Pinecone index '{self.index_name}'")

            # Connect to the index
            self.index = self.client.Index(self.index_name)
        except Exception as e:
            log.error(f"Failed to initialize Pinecone index: {e}")
            raise RuntimeError(f"Failed to initialize Pinecone index: {e}")

    def _create_points(
        self, items: List[VectorItem], collection_name_with_prefix: str
    ) -> List[Dict[str, Any]]:
        """Convert VectorItem objects to Pinecone point format."""
        points = []
        for item in items:
            item_metadata = item.get("metadata", {}) or {}
            user_id = item_metadata.get("created_by", "unknown")
            chat_id = item_metadata.get("chat_id")
            filename = item_metadata.get("name")
            text = item.get("text")
            model = item_metadata.get("model")
            topic = item_metadata.get("topic")

            # Infer the source: a raw source that is just the uploaded filename
            # (or ends in a document extension) is treated as a chat upload when
            # the item has a creator, otherwise as knowledge-base content.
            raw_source = item_metadata.get("source", "")
            if raw_source == filename or (
                isinstance(raw_source, str)
                and raw_source.endswith((".pdf", ".txt", ".docx"))
            ):
                inferred_source = (
                    "chat" if item_metadata.get("created_by") else "knowledge"
                )
            else:
                inferred_source = raw_source or "knowledge"

            metadata = build_metadata(
                source=inferred_source,
                type_="upload",
                user_id=user_id,
                chat_id=chat_id,
                filename=filename,
                text=text,
                model=model,
                topic=topic,
                collection_name=collection_name_with_prefix,
            )

            points.append(
                {
                    "id": item["id"],
                    "values": item["vector"],
                    "metadata": metadata,
                }
            )
        return points
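
    # Illustrative shape (not executed): each point produced above looks like
    #   {"id": "doc-1#0", "values": [0.12, ...],
    #    "metadata": {"source": ..., "text": ..., "collection_name": ...}}
    # which is the dict form accepted by Pinecone's Index.upsert().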

    def _get_namespace(self) -> str:
        """Return the configured Pinecone namespace."""
        return self.namespace

    def _normalize_distance(self, score: float) -> float:
        """Normalize a similarity score based on the metric in use."""
        if self.metric.lower() == "cosine":
            # Cosine similarity ranges from -1 to 1; normalize to 0 to 1
            return (score + 1.0) / 2.0
        elif self.metric.lower() in ["euclidean", "dotproduct"]:
            # Already suitable for ranking (smaller is better for Euclidean)
            return score
        else:
            # For other metrics, use the score as-is
            return score
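
    # Example mapping under the cosine branch (illustrative): a raw Pinecone
    # score of 1.0 normalizes to 1.0, 0.0 to 0.5, and -1.0 to 0.0.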

    def _result_to_get_result(self, matches: list) -> GetResult:
        """Convert Pinecone matches to GetResult format."""
        ids = []
        documents = []
        metadatas = []

        for match in matches:
            metadata = match.get("metadata", {})
            ids.append(match["id"])
            documents.append(metadata.get("text", ""))
            metadatas.append(metadata)

        return GetResult(
            ids=[ids],
            documents=[documents],
            metadatas=[metadatas],
        )
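
    # Note the nested single-element lists: GetResult expects one inner list
    # per query, e.g. ids=[["id-1", "id-2"]] for a single query's matches.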

    def has_collection(self, collection_name: str) -> bool:
        """Check if a collection exists by searching for at least one item."""
        collection_name_with_prefix = self._get_namespace()
        try:
            # Search for at least one vector carrying this collection name in metadata
            response = self.index.query(
                vector=[0.0] * self.dimension,  # dummy vector
                top_k=1,
                filter={"collection_name": collection_name_with_prefix},
                include_metadata=False,
            )
            return len(response.matches) > 0
        except Exception as e:
            log.exception(
                f"Error checking collection '{collection_name_with_prefix}': {e}"
            )
            return False

    def delete_collection(self, collection_name: str) -> None:
        """Delete a collection by removing all vectors with the collection name in metadata."""
        collection_name_with_prefix = self._get_namespace()
        try:
            self.index.delete(filter={"collection_name": collection_name_with_prefix})
            log.info(
                f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)."
            )
        except Exception as e:
            log.warning(
                f"Failed to delete collection '{collection_name_with_prefix}': {e}"
            )
            raise

    async def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert vectors into a collection."""
        if not items:
            log.warning("No items to insert")
            return

        collection_name_with_prefix = self._get_namespace()
        points = self._create_points(items, collection_name_with_prefix)

        # Insert in batches for better performance and reliability
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            try:
                await asyncio.to_thread(self.index.upsert, vectors=batch)
            except Exception as e:
                log.error(
                    f"Error inserting batch into '{collection_name_with_prefix}': {e}"
                )
                raise

        log.info(
            f"Successfully inserted {len(items)} vectors into '{collection_name_with_prefix}'"
        )

    async def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Upsert (insert or update) vectors into a collection."""
        if not items:
            log.warning("No items to upsert")
            return

        collection_name_with_prefix = self._get_namespace()
        points = self._create_points(items, collection_name_with_prefix)

        # Upsert in batches
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            try:
                await asyncio.to_thread(self.index.upsert, vectors=batch)
            except Exception as e:
                log.error(
                    f"Error upserting batch into '{collection_name_with_prefix}': {e}"
                )
                raise

        log.info(
            f"Successfully upserted {len(items)} vectors into '{collection_name_with_prefix}'"
        )

    def search(
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
    ) -> Optional[SearchResult]:
        """Search for similar vectors in a collection."""
        if not vectors or not vectors[0]:
            log.warning("No vectors provided for search")
            return None

        collection_name_with_prefix = self._get_namespace()
        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Search using the first vector (assuming this is the intended behavior)
            query_vector = vectors[0]

            # Perform the search
            query_response = self.index.query(
                vector=query_vector,
                top_k=limit,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )

            if not query_response.matches:
                # Return an empty result if there are no matches
                return SearchResult(
                    ids=[[]],
                    documents=[[]],
                    metadatas=[[]],
                    distances=[[]],
                )

            # Convert to GetResult format
            get_result = self._result_to_get_result(query_response.matches)

            # Calculate normalized distances based on the metric
            distances = [
                [
                    self._normalize_distance(match.score)
                    for match in query_response.matches
                ]
            ]

            return SearchResult(
                ids=get_result.ids,
                documents=get_result.documents,
                metadatas=get_result.metadatas,
                distances=distances,
            )
        except Exception as e:
            log.error(f"Error searching in '{collection_name_with_prefix}': {e}")
            return None

    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Query vectors by metadata filter."""
        collection_name_with_prefix = self._get_namespace()
        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Pinecone requires a vector, so use a zero vector for metadata-only queries
            zero_vector = [0.0] * self.dimension

            # Combine the user filter with the collection name
            pinecone_filter = {"collection_name": collection_name_with_prefix}
            if filter:
                pinecone_filter.update(filter)

            query_response = self.index.query(
                vector=zero_vector,
                filter=pinecone_filter,
                top_k=limit,
                include_metadata=True,
            )
            return self._result_to_get_result(query_response.matches)
        except Exception as e:
            log.error(f"Error querying collection '{collection_name}': {e}")
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Get all vectors in a collection."""
        collection_name_with_prefix = self._get_namespace()
        try:
            # Use a zero vector to fetch entries by metadata filter only
            zero_vector = [0.0] * self.dimension

            query_response = self.index.query(
                vector=zero_vector,
                top_k=NO_LIMIT,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )
            return self._result_to_get_result(query_response.matches)
        except Exception as e:
            log.error(f"Error getting collection '{collection_name}': {e}")
            return None

    async def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        """Delete vectors by IDs or filter."""
        collection_name_with_prefix = self._get_namespace()
        try:
            if ids:
                # Delete by IDs, in batches for large deletions.
                # Note: when deleting by ID, Pinecone cannot also filter by
                # collection_name, so be careful with ID uniqueness.
                for i in range(0, len(ids), BATCH_SIZE):
                    batch_ids = ids[i : i + BATCH_SIZE]
                    await asyncio.to_thread(self.index.delete, ids=batch_ids)
                log.info(
                    f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
                )
            elif filter:
                # Combine the user filter with the collection name
                pinecone_filter = {"collection_name": collection_name_with_prefix}
                pinecone_filter.update(filter)

                # Delete by metadata filter
                self.index.delete(filter=pinecone_filter)
                log.info(
                    f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'"
                )
            else:
                log.warning("No ids or filter provided for delete operation")
        except Exception as e:
            log.error(f"Error deleting from collection '{collection_name}': {e}")
            raise

    def reset(self) -> None:
        """Reset the database by deleting all vectors in the index."""
        try:
            self.index.delete(delete_all=True)
            log.info("All vectors successfully deleted from the index.")
        except Exception as e:
            log.error(f"Failed to reset Pinecone index: {e}")
            raise
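

# --- Usage sketch (illustrative) --------------------------------------------
# A minimal sketch of how this client might be exercised, assuming the
# PINECONE_* configuration values are set (with PINECONE_DIMENSION an int)
# and the index is reachable. The collection name and item below are
# hypothetical and only run when the module is executed directly.
if __name__ == "__main__":

    async def _demo() -> None:
        client = PineconeClient()
        items = [
            {
                "id": "doc-1#0",  # hypothetical chunk ID
                "vector": [0.1] * client.dimension,
                "text": "hello world",
                "metadata": {"created_by": "u-123", "name": "hello.txt"},
            }
        ]
        await client.upsert("demo-collection", items)
        result = client.search(
            "demo-collection", [[0.1] * client.dimension], limit=5
        )
        print(result)

    asyncio.run(_demo())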