pgvector.py

from typing import Optional, List, Dict, Any
import logging
import json

from sqlalchemy import (
    func,
    literal,
    cast,
    column,
    create_engine,
    Column,
    Integer,
    MetaData,
    LargeBinary,
    select,
    text,
    Text,
    Table,
    values,
)
from sqlalchemy.sql import true
from sqlalchemy.pool import NullPool, QueuePool
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, array
from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError

from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    PGVECTOR_DB_URL,
    PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
    PGVECTOR_PGCRYPTO,
    PGVECTOR_PGCRYPTO_KEY,
    PGVECTOR_POOL_SIZE,
    PGVECTOR_POOL_MAX_OVERFLOW,
    PGVECTOR_POOL_TIMEOUT,
    PGVECTOR_POOL_RECYCLE,
)
from open_webui.env import SRC_LOG_LEVELS

VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
Base = declarative_base()

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def pgcrypto_encrypt(val, key):
    return func.pgp_sym_encrypt(val, literal(key))


def pgcrypto_decrypt(col, key, outtype=Text):
    # Use SQLAlchemy's cast() construct (rather than func.cast) so this
    # renders as standard `CAST(... AS type)` SQL and accepts a type object.
    return cast(func.pgp_sym_decrypt(col, literal(key)), outtype)
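
# Illustration (a sketch, not executed anywhere in this module): with
# outtype=Text, pgcrypto_decrypt(DocumentChunk.text, key, Text) compiles to
# SQL along the lines of:
#   CAST(pgp_sym_decrypt(document_chunk.text, '<key>') AS TEXT)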


class DocumentChunk(Base):
    __tablename__ = "document_chunk"

    id = Column(Text, primary_key=True)
    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
    collection_name = Column(Text, nullable=False)

    if PGVECTOR_PGCRYPTO:
        # With pgcrypto enabled, text and metadata are stored encrypted as BYTEA.
        text = Column(LargeBinary, nullable=True)
        vmetadata = Column(LargeBinary, nullable=True)
    else:
        text = Column(Text, nullable=True)
        vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)


class PgvectorClient(VectorDBBase):
    def __init__(self) -> None:
        # if no pgvector uri, use the existing database connection
        if not PGVECTOR_DB_URL:
            from open_webui.internal.db import Session

            self.session = Session
        else:
            if isinstance(PGVECTOR_POOL_SIZE, int):
                if PGVECTOR_POOL_SIZE > 0:
                    engine = create_engine(
                        PGVECTOR_DB_URL,
                        pool_size=PGVECTOR_POOL_SIZE,
                        max_overflow=PGVECTOR_POOL_MAX_OVERFLOW,
                        pool_timeout=PGVECTOR_POOL_TIMEOUT,
                        pool_recycle=PGVECTOR_POOL_RECYCLE,
                        pool_pre_ping=True,
                        poolclass=QueuePool,
                    )
                else:
                    engine = create_engine(
                        PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
                    )
            else:
                engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True)

            SessionLocal = sessionmaker(
                autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
            )
            self.session = scoped_session(SessionLocal)

        try:
            # Ensure the pgvector extension is available
            self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))

            if PGVECTOR_PGCRYPTO:
                # Ensure the pgcrypto extension is available for encryption
                self.session.execute(text("CREATE EXTENSION IF NOT EXISTS pgcrypto;"))
                if not PGVECTOR_PGCRYPTO_KEY:
                    raise ValueError(
                        "PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled."
                    )

            # Check vector length consistency
            self.check_vector_length()

            # Create the tables if they do not exist
            # Base.metadata.create_all requires a bind (engine or connection)
            # Get the connection from the session
            connection = self.session.connection()
            Base.metadata.create_all(bind=connection)

            # Create an index on the vector column if it doesn't exist
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
                    "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
                )
            )
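            # Note: ivfflat is an approximate index: vectors are clustered into
            # "lists" (100 here) and only the nearest clusters are scanned at
            # query time, trading a little recall for speed (see the pgvector
            # docs for guidance on sizing lists).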
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
                    "ON document_chunk (collection_name);"
                )
            )
            self.session.commit()
            log.info("Initialization complete.")
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during initialization: {e}")
            raise

    def check_vector_length(self) -> None:
        """
        Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
        Raises an exception if there is a mismatch.
        """
        metadata = MetaData()
        try:
            # Attempt to reflect the 'document_chunk' table
            document_chunk_table = Table(
                "document_chunk", metadata, autoload_with=self.session.bind
            )
        except NoSuchTableError:
            # Table does not exist; no action needed
            return

        # Proceed to check the vector column
        if "vector" in document_chunk_table.columns:
            vector_column = document_chunk_table.columns["vector"]
            vector_type = vector_column.type
            if isinstance(vector_type, Vector):
                db_vector_length = vector_type.dim
                if db_vector_length != VECTOR_LENGTH:
                    raise Exception(
                        f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
                        "Cannot change vector size after initialization without migrating the data."
                    )
            else:
                raise Exception(
                    "The 'vector' column exists but is not of type 'Vector'."
                )
        else:
            raise Exception(
                "The 'vector' column does not exist in the 'document_chunk' table."
            )

    def adjust_vector_length(self, vector: List[float]) -> List[float]:
        # Adjust vector to have length VECTOR_LENGTH
        current_length = len(vector)
        if current_length < VECTOR_LENGTH:
            # Pad the vector with zeros
            vector += [0.0] * (VECTOR_LENGTH - current_length)
        elif current_length > VECTOR_LENGTH:
            # Truncate the vector to VECTOR_LENGTH
            vector = vector[:VECTOR_LENGTH]
        return vector
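
    # For example, with VECTOR_LENGTH = 4, adjust_vector_length([0.1, 0.2])
    # returns [0.1, 0.2, 0.0, 0.0], while a 6-dimensional input is truncated
    # to its first four components.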

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        try:
            if PGVECTOR_PGCRYPTO:
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    # Use raw SQL for BYTEA/pgcrypto
                    self.session.execute(
                        text(
                            """
                            INSERT INTO document_chunk
                                (id, vector, collection_name, text, vmetadata)
                            VALUES (
                                :id, :vector, :collection_name,
                                pgp_sym_encrypt(:text, :key),
                                pgp_sym_encrypt(:metadata::text, :key)
                            )
                            ON CONFLICT (id) DO NOTHING
                            """
                        ),
                        {
                            "id": item["id"],
                            "vector": vector,
                            "collection_name": collection_name,
                            "text": item["text"],
                            "metadata": json.dumps(item["metadata"]),
                            "key": PGVECTOR_PGCRYPTO_KEY,
                        },
                    )
                self.session.commit()
                log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'")
            else:
                new_items = []
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    new_chunk = DocumentChunk(
                        id=item["id"],
                        vector=vector,
                        collection_name=collection_name,
                        text=item["text"],
                        vmetadata=item["metadata"],
                    )
                    new_items.append(new_chunk)
                self.session.bulk_save_objects(new_items)
                self.session.commit()
                log.info(
                    f"Inserted {len(new_items)} items into collection '{collection_name}'."
                )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during insert: {e}")
            raise

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        try:
            if PGVECTOR_PGCRYPTO:
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    self.session.execute(
                        text(
                            """
                            INSERT INTO document_chunk
                                (id, vector, collection_name, text, vmetadata)
                            VALUES (
                                :id, :vector, :collection_name,
                                pgp_sym_encrypt(:text, :key),
                                pgp_sym_encrypt(:metadata::text, :key)
                            )
                            ON CONFLICT (id) DO UPDATE SET
                                vector = EXCLUDED.vector,
                                collection_name = EXCLUDED.collection_name,
                                text = EXCLUDED.text,
                                vmetadata = EXCLUDED.vmetadata
                            """
                        ),
                        {
                            "id": item["id"],
                            "vector": vector,
                            "collection_name": collection_name,
                            "text": item["text"],
                            "metadata": json.dumps(item["metadata"]),
                            "key": PGVECTOR_PGCRYPTO_KEY,
                        },
                    )
                self.session.commit()
                log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
            else:
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    existing = (
                        self.session.query(DocumentChunk)
                        .filter(DocumentChunk.id == item["id"])
                        .first()
                    )
                    if existing:
                        existing.vector = vector
                        existing.text = item["text"]
                        existing.vmetadata = item["metadata"]
                        # Update collection_name if necessary
                        existing.collection_name = collection_name
                    else:
                        new_chunk = DocumentChunk(
                            id=item["id"],
                            vector=vector,
                            collection_name=collection_name,
                            text=item["text"],
                            vmetadata=item["metadata"],
                        )
                        self.session.add(new_chunk)
                self.session.commit()
                log.info(
                    f"Upserted {len(items)} items into collection '{collection_name}'."
                )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during upsert: {e}")
            raise

    def search(
        self,
        collection_name: str,
        vectors: List[List[float]],
        limit: Optional[int] = None,
    ) -> Optional[SearchResult]:
        try:
            if not vectors:
                return None

            # Adjust query vectors to VECTOR_LENGTH
            vectors = [self.adjust_vector_length(vector) for vector in vectors]
            num_queries = len(vectors)

            def vector_expr(vector):
                return cast(array(vector), Vector(VECTOR_LENGTH))

            # Create the values for query vectors
            qid_col = column("qid", Integer)
            q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
            query_vectors = (
                values(qid_col, q_vector_col)
                .data(
                    [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
                )
                .alias("query_vectors")
            )

            result_fields = [
                DocumentChunk.id,
            ]
            if PGVECTOR_PGCRYPTO:
                result_fields.append(
                    pgcrypto_decrypt(
                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
                    ).label("text")
                )
                result_fields.append(
                    pgcrypto_decrypt(
                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                    ).label("vmetadata")
                )
            else:
                result_fields.append(DocumentChunk.text)
                result_fields.append(DocumentChunk.vmetadata)
            result_fields.append(
                (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
                    "distance"
                )
            )

            # Build the lateral subquery for each query vector
            subq = (
                select(*result_fields)
                .where(DocumentChunk.collection_name == collection_name)
                .order_by(
                    (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
                )
            )
            if limit is not None:
                subq = subq.limit(limit)
            subq = subq.lateral("result")

            # Build the main query by joining query_vectors and the lateral subquery
            stmt = (
                select(
                    query_vectors.c.qid,
                    subq.c.id,
                    subq.c.text,
                    subq.c.vmetadata,
                    subq.c.distance,
                )
                .select_from(query_vectors)
                .join(subq, true())
                .order_by(query_vectors.c.qid, subq.c.distance)
            )

            result_proxy = self.session.execute(stmt)
            results = result_proxy.all()

            ids = [[] for _ in range(num_queries)]
            distances = [[] for _ in range(num_queries)]
            documents = [[] for _ in range(num_queries)]
            metadatas = [[] for _ in range(num_queries)]

            if not results:
                return SearchResult(
                    ids=ids,
                    distances=distances,
                    documents=documents,
                    metadatas=metadatas,
                )

            for row in results:
                qid = int(row.qid)
                ids[qid].append(row.id)
                # Normalize and invert the pgvector cosine distance from its
                # [0, 2] range into a [0, 1] similarity score (higher is better).
                # https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
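                # e.g. identical vectors (distance 0.0) score 1.0, orthogonal
                # vectors (distance 1.0) score 0.5, and opposite vectors
                # (distance 2.0) score 0.0.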
                distances[qid].append((2.0 - row.distance) / 2.0)
                documents[qid].append(row.text)
                metadatas[qid].append(row.vmetadata)

            return SearchResult(
                ids=ids, distances=distances, documents=documents, metadatas=metadatas
            )
        except Exception as e:
            log.exception(f"Error during search: {e}")
            return None

    def query(
        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
    ) -> Optional[GetResult]:
        try:
            if PGVECTOR_PGCRYPTO:
                # Build where clause for vmetadata filter
                where_clauses = [DocumentChunk.collection_name == collection_name]
                for key, value in filter.items():
                    # decrypt then check key: JSON filter after decryption
                    where_clauses.append(
                        pgcrypto_decrypt(
                            DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                        )[key].astext
                        == str(value)
                    )
                stmt = select(
                    DocumentChunk.id,
                    pgcrypto_decrypt(
                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
                    ).label("text"),
                    pgcrypto_decrypt(
                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                    ).label("vmetadata"),
                ).where(*where_clauses)
                if limit is not None:
                    stmt = stmt.limit(limit)
                results = self.session.execute(stmt).all()
            else:
                query = self.session.query(DocumentChunk).filter(
                    DocumentChunk.collection_name == collection_name
                )
                for key, value in filter.items():
                    query = query.filter(
                        DocumentChunk.vmetadata[key].astext == str(value)
                    )
                if limit is not None:
                    query = query.limit(limit)
                results = query.all()

            if not results:
                return None

            ids = [[result.id for result in results]]
            documents = [[result.text for result in results]]
            metadatas = [[result.vmetadata for result in results]]
            return GetResult(
                ids=ids,
                documents=documents,
                metadatas=metadatas,
            )
        except Exception as e:
            log.exception(f"Error during query: {e}")
            return None

    def get(
        self, collection_name: str, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        try:
            if PGVECTOR_PGCRYPTO:
                stmt = select(
                    DocumentChunk.id,
                    pgcrypto_decrypt(
                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
                    ).label("text"),
                    pgcrypto_decrypt(
                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                    ).label("vmetadata"),
                ).where(DocumentChunk.collection_name == collection_name)
                if limit is not None:
                    stmt = stmt.limit(limit)
                results = self.session.execute(stmt).all()
                ids = [[row.id for row in results]]
                documents = [[row.text for row in results]]
                metadatas = [[row.vmetadata for row in results]]
            else:
                query = self.session.query(DocumentChunk).filter(
                    DocumentChunk.collection_name == collection_name
                )
                if limit is not None:
                    query = query.limit(limit)
                results = query.all()

                if not results:
                    return None

                ids = [[result.id for result in results]]
                documents = [[result.text for result in results]]
                metadatas = [[result.vmetadata for result in results]]
            return GetResult(ids=ids, documents=documents, metadatas=metadatas)
        except Exception as e:
            log.exception(f"Error during get: {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ) -> None:
        try:
            if PGVECTOR_PGCRYPTO:
                wheres = [DocumentChunk.collection_name == collection_name]
                if ids:
                    wheres.append(DocumentChunk.id.in_(ids))
                if filter:
                    for key, value in filter.items():
                        wheres.append(
                            pgcrypto_decrypt(
                                DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                            )[key].astext
                            == str(value)
                        )
                stmt = DocumentChunk.__table__.delete().where(*wheres)
                result = self.session.execute(stmt)
                deleted = result.rowcount
            else:
                query = self.session.query(DocumentChunk).filter(
                    DocumentChunk.collection_name == collection_name
                )
                if ids:
                    query = query.filter(DocumentChunk.id.in_(ids))
                if filter:
                    for key, value in filter.items():
                        query = query.filter(
                            DocumentChunk.vmetadata[key].astext == str(value)
                        )
                deleted = query.delete(synchronize_session=False)
            self.session.commit()
            log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during delete: {e}")
            raise

    def reset(self) -> None:
        try:
            deleted = self.session.query(DocumentChunk).delete()
            self.session.commit()
            log.info(
                f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during reset: {e}")
            raise

    def close(self) -> None:
        # No-op: nothing to close explicitly here.
        pass

    def has_collection(self, collection_name: str) -> bool:
        try:
            exists = (
                self.session.query(DocumentChunk)
                .filter(DocumentChunk.collection_name == collection_name)
                .first()
                is not None
            )
            return exists
        except Exception as e:
            log.exception(f"Error checking collection existence: {e}")
            return False

    def delete_collection(self, collection_name: str) -> None:
        self.delete(collection_name)
        log.info(f"Collection '{collection_name}' deleted.")
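

# Example usage (illustrative sketch only; assumes PGVECTOR_DB_URL points at a
# reachable Postgres instance with the pgvector extension, and "my_collection"
# is a hypothetical collection name):
#
#   client = PgvectorClient()
#   client.upsert(
#       "my_collection",
#       [
#           {
#               "id": "doc-1",
#               "vector": [0.1] * VECTOR_LENGTH,
#               "text": "hello world",
#               "metadata": {"source": "demo"},
#           }
#       ],
#   )
#   result = client.search("my_collection", [[0.1] * VECTOR_LENGTH], limit=5)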