pgvector.py

from typing import Optional, List, Dict, Any
import logging

from sqlalchemy import (
    cast,
    column,
    create_engine,
    Column,
    Integer,
    MetaData,
    select,
    text,
    Text,
    Table,
    values,
)
from sqlalchemy.sql import true
from sqlalchemy.pool import NullPool
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, array
from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError

from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
from open_webui.env import SRC_LOG_LEVELS

VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
Base = declarative_base()

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class DocumentChunk(Base):
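    """ORM model for one embedded document chunk, keyed by chunk id."""
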
    __tablename__ = "document_chunk"

    id = Column(Text, primary_key=True)
    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
    collection_name = Column(Text, nullable=False)
    text = Column(Text, nullable=True)
    vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)


class PgvectorClient(VectorDBBase):
    def __init__(self) -> None:
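        """
        Open a database session (a dedicated engine when PGVECTOR_DB_URL is
        set, otherwise Open WebUI's shared internal session), ensure the
        pgvector extension and schema exist, and create supporting indexes.
        """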
        # If no pgvector URI is configured, reuse the existing database connection.
        if not PGVECTOR_DB_URL:
            from open_webui.internal.db import Session

            self.session = Session
        else:
            engine = create_engine(
                PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
            )
            SessionLocal = sessionmaker(
                autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
            )
            self.session = scoped_session(SessionLocal)

        try:
            # Ensure the pgvector extension is available
            self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))

            # Check vector length consistency
            self.check_vector_length()

            # Create the tables if they do not exist.
            # Base.metadata.create_all requires a bind (engine or connection),
            # so get the connection from the session.
            connection = self.session.connection()
            Base.metadata.create_all(bind=connection)

            # Create indexes on the vector and collection_name columns if they
            # don't exist.
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
                    "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
                )
            )
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
                    "ON document_chunk (collection_name);"
                )
            )

            self.session.commit()
            log.info("Initialization complete.")
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during initialization: {e}")
            raise

    def check_vector_length(self) -> None:
        """
        Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
        Raises an exception if there is a mismatch.
        """
        metadata = MetaData()
        try:
            # Attempt to reflect the 'document_chunk' table
            document_chunk_table = Table(
                "document_chunk", metadata, autoload_with=self.session.bind
            )
        except NoSuchTableError:
            # Table does not exist; no action needed
            return

        # Proceed to check the vector column
        if "vector" in document_chunk_table.columns:
            vector_column = document_chunk_table.columns["vector"]
            vector_type = vector_column.type
            if isinstance(vector_type, Vector):
                db_vector_length = vector_type.dim
                if db_vector_length != VECTOR_LENGTH:
                    raise Exception(
                        f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
                        "Cannot change vector size after initialization without migrating the data."
                    )
            else:
                raise Exception(
                    "The 'vector' column exists but is not of type 'Vector'."
                )
        else:
            raise Exception(
                "The 'vector' column does not exist in the 'document_chunk' table."
            )

    def adjust_vector_length(self, vector: List[float]) -> List[float]:
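        """Pad `vector` with zeros up to VECTOR_LENGTH; raise if it is longer."""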
        # Adjust vector to have length VECTOR_LENGTH
        current_length = len(vector)
        if current_length < VECTOR_LENGTH:
            # Pad the vector with zeros
            vector += [0.0] * (VECTOR_LENGTH - current_length)
        elif current_length > VECTOR_LENGTH:
            raise Exception(
                f"Vector length {current_length} not supported. Length must be <= {VECTOR_LENGTH}."
            )
        return vector

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
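        """Insert items as new document chunks in the given collection."""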
        try:
            new_items = []
            for item in items:
                vector = self.adjust_vector_length(item["vector"])
                new_chunk = DocumentChunk(
                    id=item["id"],
                    vector=vector,
                    collection_name=collection_name,
                    text=item["text"],
                    vmetadata=item["metadata"],
                )
                new_items.append(new_chunk)
            self.session.bulk_save_objects(new_items)
            self.session.commit()
            log.info(
                f"Inserted {len(new_items)} items into collection '{collection_name}'."
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during insert: {e}")
            raise

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
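        """Update existing chunks matched by id, or insert them if absent."""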
        try:
            for item in items:
                vector = self.adjust_vector_length(item["vector"])
                existing = (
                    self.session.query(DocumentChunk)
                    .filter(DocumentChunk.id == item["id"])
                    .first()
                )
                if existing:
                    existing.vector = vector
                    existing.text = item["text"]
                    existing.vmetadata = item["metadata"]
                    # Update collection_name if necessary
                    existing.collection_name = collection_name
                else:
                    new_chunk = DocumentChunk(
                        id=item["id"],
                        vector=vector,
                        collection_name=collection_name,
                        text=item["text"],
                        vmetadata=item["metadata"],
                    )
                    self.session.add(new_chunk)
            self.session.commit()
            log.info(
                f"Upserted {len(items)} items into collection '{collection_name}'."
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during upsert: {e}")
            raise

    def search(
        self,
        collection_name: str,
        vectors: List[List[float]],
        limit: Optional[int] = None,
    ) -> Optional[SearchResult]:
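        """
        Run a batched nearest-neighbor search: for each query vector, return up
        to `limit` chunks from the collection, ordered by cosine distance via a
        lateral join. Distances are converted to [0, 1] similarity scores
        (1 = most similar).
        """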
        try:
            if not vectors:
                return None

            # Adjust query vectors to VECTOR_LENGTH
            vectors = [self.adjust_vector_length(vector) for vector in vectors]
            num_queries = len(vectors)

            def vector_expr(vector):
                return cast(array(vector), Vector(VECTOR_LENGTH))

            # Create the VALUES construct holding the query vectors
            qid_col = column("qid", Integer)
            q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
            query_vectors = (
                values(qid_col, q_vector_col)
                .data(
                    [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
                )
                .alias("query_vectors")
            )

            # Build the lateral subquery for each query vector
            subq = (
                select(
                    DocumentChunk.id,
                    DocumentChunk.text,
                    DocumentChunk.vmetadata,
                    (
                        DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
                    ).label("distance"),
                )
                .where(DocumentChunk.collection_name == collection_name)
                .order_by(
                    (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
                )
            )
            if limit is not None:
                subq = subq.limit(limit)
            subq = subq.lateral("result")

            # Build the main query by joining query_vectors and the lateral subquery
            stmt = (
                select(
                    query_vectors.c.qid,
                    subq.c.id,
                    subq.c.text,
                    subq.c.vmetadata,
                    subq.c.distance,
                )
                .select_from(query_vectors)
                .join(subq, true())
                .order_by(query_vectors.c.qid, subq.c.distance)
            )

            result_proxy = self.session.execute(stmt)
            results = result_proxy.all()

            ids = [[] for _ in range(num_queries)]
            distances = [[] for _ in range(num_queries)]
            documents = [[] for _ in range(num_queries)]
            metadatas = [[] for _ in range(num_queries)]

            if not results:
                return SearchResult(
                    ids=ids,
                    distances=distances,
                    documents=documents,
                    metadatas=metadatas,
                )

            for row in results:
                qid = int(row.qid)
                ids[qid].append(row.id)
                # Normalize and reorder the pgvector cosine distance from the
                # [2, 0] range to a [0, 1] similarity score (1 = most similar).
                # https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
                distances[qid].append((2.0 - row.distance) / 2.0)
                documents[qid].append(row.text)
                metadatas[qid].append(row.vmetadata)

            return SearchResult(
                ids=ids, distances=distances, documents=documents, metadatas=metadatas
            )
        except Exception as e:
            log.exception(f"Error during search: {e}")
            return None

    def query(
        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
    ) -> Optional[GetResult]:
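        """Fetch chunks whose metadata matches `filter` (string comparison per key)."""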
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            for key, value in filter.items():
                query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
            if limit is not None:
                query = query.limit(limit)

            results = query.all()
            if not results:
                return None

            ids = [[result.id for result in results]]
            documents = [[result.text for result in results]]
            metadatas = [[result.vmetadata for result in results]]

            return GetResult(
                ids=ids,
                documents=documents,
                metadatas=metadatas,
            )
        except Exception as e:
            log.exception(f"Error during query: {e}")
            return None

    def get(
        self, collection_name: str, limit: Optional[int] = None
    ) -> Optional[GetResult]:
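        """Fetch up to `limit` chunks from a collection without any filtering."""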
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            if limit is not None:
                query = query.limit(limit)

            results = query.all()
            if not results:
                return None

            ids = [[result.id for result in results]]
            documents = [[result.text for result in results]]
            metadatas = [[result.vmetadata for result in results]]

            return GetResult(ids=ids, documents=documents, metadatas=metadatas)
        except Exception as e:
            log.exception(f"Error during get: {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ) -> None:
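        """Delete chunks from a collection, optionally narrowed by ids and/or a metadata filter."""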
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            if ids:
                query = query.filter(DocumentChunk.id.in_(ids))
            if filter:
                for key, value in filter.items():
                    query = query.filter(
                        DocumentChunk.vmetadata[key].astext == str(value)
                    )
            deleted = query.delete(synchronize_session=False)
            self.session.commit()
            log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during delete: {e}")
            raise

    def reset(self) -> None:
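        """Delete every chunk in the 'document_chunk' table, across all collections."""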
        try:
            deleted = self.session.query(DocumentChunk).delete()
            self.session.commit()
            log.info(
                f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during reset: {e}")
            raise

    def close(self) -> None:
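        """Close the client. Currently a no-op."""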
        pass

    def has_collection(self, collection_name: str) -> bool:
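        """Return True if at least one chunk exists for the given collection."""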
        try:
            exists = (
                self.session.query(DocumentChunk)
                .filter(DocumentChunk.collection_name == collection_name)
                .first()
                is not None
            )
            return exists
        except Exception as e:
            log.exception(f"Error checking collection existence: {e}")
            return False

    def delete_collection(self, collection_name: str) -> None:
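        """Delete all chunks belonging to a collection."""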
        self.delete(collection_name)
        log.info(f"Collection '{collection_name}' deleted.")
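

# Example usage (illustrative sketch only, not part of the original module).
# It assumes Open WebUI is configured with a reachable Postgres instance, and
# the collection name, id, vector, text, and metadata below are placeholders;
# vectors shorter than VECTOR_LENGTH are zero-padded internally by
# adjust_vector_length().
#
#   client = PgvectorClient()
#   client.upsert(
#       "demo_collection",
#       [
#           {
#               "id": "chunk-1",
#               "vector": [0.1, 0.2, 0.3],
#               "text": "hello world",
#               "metadata": {"source": "example.txt"},
#           }
#       ],
#   )
#   result = client.search("demo_collection", vectors=[[0.1, 0.2, 0.3]], limit=5)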