pgvector.py

from typing import Optional, List, Dict, Any
import logging
import json

from sqlalchemy import (
    func,
    literal,
    cast,
    column,
    create_engine,
    Column,
    Integer,
    MetaData,
    LargeBinary,
    select,
    text,
    Text,
    Table,
    values,
)
from sqlalchemy.sql import true
from sqlalchemy.pool import NullPool, QueuePool
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, array
from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError

from open_webui.retrieval.vector.utils import stringify_metadata
from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    PGVECTOR_DB_URL,
    PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
    PGVECTOR_PGCRYPTO,
    PGVECTOR_PGCRYPTO_KEY,
    PGVECTOR_POOL_SIZE,
    PGVECTOR_POOL_MAX_OVERFLOW,
    PGVECTOR_POOL_TIMEOUT,
    PGVECTOR_POOL_RECYCLE,
)
from open_webui.env import SRC_LOG_LEVELS

VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
Base = declarative_base()

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def pgcrypto_encrypt(val, key):
    # Encrypt a value with pgcrypto's symmetric-key encryption.
    return func.pgp_sym_encrypt(val, literal(key))


def pgcrypto_decrypt(col, key, outtype="text"):
    # Decrypt a BYTEA column and cast the result to the requested SQL type.
    # Note: SQLAlchemy's cast() construct (imported above) is used here;
    # func.cast() would render invalid CAST(expr, type) syntax.
    return cast(func.pgp_sym_decrypt(col, literal(key)), outtype)
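
# Roughly, pgcrypto_decrypt(DocumentChunk.text, key, Text) renders SQL like:
#   CAST(pgp_sym_decrypt(document_chunk.text, '<key>') AS TEXT)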


class DocumentChunk(Base):
    __tablename__ = "document_chunk"

    id = Column(Text, primary_key=True)
    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
    collection_name = Column(Text, nullable=False)

    if PGVECTOR_PGCRYPTO:
        text = Column(LargeBinary, nullable=True)
        vmetadata = Column(LargeBinary, nullable=True)
    else:
        text = Column(Text, nullable=True)
        vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
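
# Note: the text/vmetadata column types are fixed at import time by the
# PGVECTOR_PGCRYPTO flag; toggling the flag against an existing table would
# require migrating the stored data.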


class PgvectorClient(VectorDBBase):
    def __init__(self) -> None:
        # If no dedicated pgvector URL is set, reuse the existing database connection.
        if not PGVECTOR_DB_URL:
            from open_webui.internal.db import Session

            self.session = Session
        else:
            if isinstance(PGVECTOR_POOL_SIZE, int):
                if PGVECTOR_POOL_SIZE > 0:
                    engine = create_engine(
                        PGVECTOR_DB_URL,
                        pool_size=PGVECTOR_POOL_SIZE,
                        max_overflow=PGVECTOR_POOL_MAX_OVERFLOW,
                        pool_timeout=PGVECTOR_POOL_TIMEOUT,
                        pool_recycle=PGVECTOR_POOL_RECYCLE,
                        pool_pre_ping=True,
                        poolclass=QueuePool,
                    )
                else:
                    engine = create_engine(
                        PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
                    )
            else:
                engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True)

            SessionLocal = sessionmaker(
                autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
            )
            self.session = scoped_session(SessionLocal)

        try:
            # Ensure the pgvector extension is available.
            # Use a conditional check to avoid permission issues on Azure PostgreSQL.
            self.session.execute(
                text(
                    """
                    DO $$
                    BEGIN
                        IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN
                            CREATE EXTENSION IF NOT EXISTS vector;
                        END IF;
                    END $$;
                    """
                )
            )

            if PGVECTOR_PGCRYPTO:
                # Ensure the pgcrypto extension is available for encryption.
                # Use a conditional check to avoid permission issues on Azure PostgreSQL.
                self.session.execute(
                    text(
                        """
                        DO $$
                        BEGIN
                            IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN
                                CREATE EXTENSION IF NOT EXISTS pgcrypto;
                            END IF;
                        END $$;
                        """
                    )
                )
                if not PGVECTOR_PGCRYPTO_KEY:
                    raise ValueError(
                        "PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled."
                    )

            # Check vector length consistency
            self.check_vector_length()

            # Create the tables if they do not exist.
            # Base.metadata.create_all requires a bind (engine or connection),
            # so get the connection from the session.
            connection = self.session.connection()
            Base.metadata.create_all(bind=connection)

            # Create an index on the vector column if it doesn't exist
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
                    "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
                )
            )
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
                    "ON document_chunk (collection_name);"
                )
            )

            self.session.commit()
            log.info("Initialization complete.")
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during initialization: {e}")
            raise

    def check_vector_length(self) -> None:
        """
        Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
        Raises an exception if there is a mismatch.
        """
        metadata = MetaData()
        try:
            # Attempt to reflect the 'document_chunk' table
            document_chunk_table = Table(
                "document_chunk", metadata, autoload_with=self.session.bind
            )
        except NoSuchTableError:
            # Table does not exist; no action needed
            return

        # Proceed to check the vector column
        if "vector" in document_chunk_table.columns:
            vector_column = document_chunk_table.columns["vector"]
            vector_type = vector_column.type
            if isinstance(vector_type, Vector):
                db_vector_length = vector_type.dim
                if db_vector_length != VECTOR_LENGTH:
                    raise Exception(
                        f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
                        "Cannot change vector size after initialization without migrating the data."
                    )
            else:
                raise Exception(
                    "The 'vector' column exists but is not of type 'Vector'."
                )
        else:
            raise Exception(
                "The 'vector' column does not exist in the 'document_chunk' table."
            )

    def adjust_vector_length(self, vector: List[float]) -> List[float]:
        # Adjust vector to have length VECTOR_LENGTH
        current_length = len(vector)
        if current_length < VECTOR_LENGTH:
            # Pad the vector with zeros
            vector += [0.0] * (VECTOR_LENGTH - current_length)
        elif current_length > VECTOR_LENGTH:
            # Truncate the vector to VECTOR_LENGTH
            vector = vector[:VECTOR_LENGTH]
        return vector
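
    # For example, with VECTOR_LENGTH == 4:
    #   adjust_vector_length([1.0, 2.0])                -> [1.0, 2.0, 0.0, 0.0]
    #   adjust_vector_length([1.0, 2.0, 3.0, 4.0, 5.0]) -> [1.0, 2.0, 3.0, 4.0]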

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        try:
            if PGVECTOR_PGCRYPTO:
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    # Use raw SQL for BYTEA/pgcrypto.
                    # Ensure metadata is converted to its JSON text representation.
                    json_metadata = json.dumps(item["metadata"])
                    self.session.execute(
                        text(
                            """
                            INSERT INTO document_chunk
                                (id, vector, collection_name, text, vmetadata)
                            VALUES (
                                :id, :vector, :collection_name,
                                pgp_sym_encrypt(:text, :key),
                                pgp_sym_encrypt(:metadata_text, :key)
                            )
                            ON CONFLICT (id) DO NOTHING
                            """
                        ),
                        {
                            "id": item["id"],
                            "vector": vector,
                            "collection_name": collection_name,
                            "text": item["text"],
                            "metadata_text": json_metadata,
                            "key": PGVECTOR_PGCRYPTO_KEY,
                        },
                    )
                self.session.commit()
                log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'")
            else:
                new_items = []
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    new_chunk = DocumentChunk(
                        id=item["id"],
                        vector=vector,
                        collection_name=collection_name,
                        text=item["text"],
                        vmetadata=stringify_metadata(item["metadata"]),
                    )
                    new_items.append(new_chunk)
                self.session.bulk_save_objects(new_items)
                self.session.commit()
                log.info(
                    f"Inserted {len(new_items)} items into collection '{collection_name}'."
                )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during insert: {e}")
            raise
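
    # Unlike insert(), upsert() overwrites rows that already exist: the
    # encrypted path relies on ON CONFLICT ... DO UPDATE, while the plain
    # path mutates the matching ORM objects in place.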

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        try:
            if PGVECTOR_PGCRYPTO:
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    json_metadata = json.dumps(item["metadata"])
                    self.session.execute(
                        text(
                            """
                            INSERT INTO document_chunk
                                (id, vector, collection_name, text, vmetadata)
                            VALUES (
                                :id, :vector, :collection_name,
                                pgp_sym_encrypt(:text, :key),
                                pgp_sym_encrypt(:metadata_text, :key)
                            )
                            ON CONFLICT (id) DO UPDATE SET
                                vector = EXCLUDED.vector,
                                collection_name = EXCLUDED.collection_name,
                                text = EXCLUDED.text,
                                vmetadata = EXCLUDED.vmetadata
                            """
                        ),
                        {
                            "id": item["id"],
                            "vector": vector,
                            "collection_name": collection_name,
                            "text": item["text"],
                            "metadata_text": json_metadata,
                            "key": PGVECTOR_PGCRYPTO_KEY,
                        },
                    )
                self.session.commit()
                log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
            else:
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    existing = (
                        self.session.query(DocumentChunk)
                        .filter(DocumentChunk.id == item["id"])
                        .first()
                    )
                    if existing:
                        existing.vector = vector
                        existing.text = item["text"]
                        existing.vmetadata = stringify_metadata(item["metadata"])
                        existing.collection_name = (
                            collection_name  # Update collection_name if necessary
                        )
                    else:
                        new_chunk = DocumentChunk(
                            id=item["id"],
                            vector=vector,
                            collection_name=collection_name,
                            text=item["text"],
                            vmetadata=stringify_metadata(item["metadata"]),
                        )
                        self.session.add(new_chunk)
                self.session.commit()
                log.info(
                    f"Upserted {len(items)} items into collection '{collection_name}'."
                )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during upsert: {e}")
            raise
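
    # search() batches all query vectors into a single SQL statement: the
    # vectors become a VALUES list, and a LATERAL subquery returns the nearest
    # chunks (by cosine distance) for each one in a single round trip.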

    def search(
        self,
        collection_name: str,
        vectors: List[List[float]],
        limit: Optional[int] = None,
    ) -> Optional[SearchResult]:
        try:
            if not vectors:
                return None

            # Adjust query vectors to VECTOR_LENGTH
            vectors = [self.adjust_vector_length(vector) for vector in vectors]
            num_queries = len(vectors)

            def vector_expr(vector):
                return cast(array(vector), Vector(VECTOR_LENGTH))

            # Create the VALUES list for the query vectors
            qid_col = column("qid", Integer)
            q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
            query_vectors = (
                values(qid_col, q_vector_col)
                .data(
                    [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
                )
                .alias("query_vectors")
            )

            result_fields = [
                DocumentChunk.id,
            ]
            if PGVECTOR_PGCRYPTO:
                result_fields.append(
                    pgcrypto_decrypt(
                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
                    ).label("text")
                )
                result_fields.append(
                    pgcrypto_decrypt(
                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                    ).label("vmetadata")
                )
            else:
                result_fields.append(DocumentChunk.text)
                result_fields.append(DocumentChunk.vmetadata)
            result_fields.append(
                (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
                    "distance"
                )
            )

            # Build the lateral subquery for each query vector
            subq = (
                select(*result_fields)
                .where(DocumentChunk.collection_name == collection_name)
                .order_by(
                    (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
                )
            )
            if limit is not None:
                subq = subq.limit(limit)
            subq = subq.lateral("result")

            # Build the main query by joining query_vectors and the lateral subquery
            stmt = (
                select(
                    query_vectors.c.qid,
                    subq.c.id,
                    subq.c.text,
                    subq.c.vmetadata,
                    subq.c.distance,
                )
                .select_from(query_vectors)
                .join(subq, true())
                .order_by(query_vectors.c.qid, subq.c.distance)
            )

            result_proxy = self.session.execute(stmt)
            results = result_proxy.all()

            ids = [[] for _ in range(num_queries)]
            distances = [[] for _ in range(num_queries)]
            documents = [[] for _ in range(num_queries)]
            metadatas = [[] for _ in range(num_queries)]

            if not results:
                return SearchResult(
                    ids=ids,
                    distances=distances,
                    documents=documents,
                    metadatas=metadatas,
                )

            for row in results:
                qid = int(row.qid)
                ids[qid].append(row.id)
                # Normalize and reorder the pgvector cosine distance from the
                # [2, 0] range to a [0, 1] score range: a distance of 0
                # (identical) maps to a score of 1.0, a distance of 2
                # (opposite) maps to 0.0.
                # https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
                distances[qid].append((2.0 - row.distance) / 2.0)
                documents[qid].append(row.text)
                metadatas[qid].append(row.vmetadata)

            self.session.rollback()  # read-only transaction
            return SearchResult(
                ids=ids, distances=distances, documents=documents, metadatas=metadatas
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during search: {e}")
            return None
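
    # query() filters by metadata equality only; no vector distance is
    # involved. Values are compared as text against the JSONB fields.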

    def query(
        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
    ) -> Optional[GetResult]:
        try:
            if PGVECTOR_PGCRYPTO:
                # Build where clause for vmetadata filter
                where_clauses = [DocumentChunk.collection_name == collection_name]
                for key, value in filter.items():
                    # Decrypt, then check the key: JSON filter after decryption
                    where_clauses.append(
                        pgcrypto_decrypt(
                            DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                        )[key].astext
                        == str(value)
                    )
                stmt = select(
                    DocumentChunk.id,
                    pgcrypto_decrypt(
                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
                    ).label("text"),
                    pgcrypto_decrypt(
                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                    ).label("vmetadata"),
                ).where(*where_clauses)
                if limit is not None:
                    stmt = stmt.limit(limit)
                results = self.session.execute(stmt).all()
            else:
                query = self.session.query(DocumentChunk).filter(
                    DocumentChunk.collection_name == collection_name
                )
                for key, value in filter.items():
                    query = query.filter(
                        DocumentChunk.vmetadata[key].astext == str(value)
                    )
                if limit is not None:
                    query = query.limit(limit)
                results = query.all()

            if not results:
                return None

            ids = [[result.id for result in results]]
            documents = [[result.text for result in results]]
            metadatas = [[result.vmetadata for result in results]]
            self.session.rollback()  # read-only transaction
            return GetResult(
                ids=ids,
                documents=documents,
                metadatas=metadatas,
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during query: {e}")
            return None

    def get(
        self, collection_name: str, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        try:
            if PGVECTOR_PGCRYPTO:
                stmt = select(
                    DocumentChunk.id,
                    pgcrypto_decrypt(
                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
                    ).label("text"),
                    pgcrypto_decrypt(
                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                    ).label("vmetadata"),
                ).where(DocumentChunk.collection_name == collection_name)
                if limit is not None:
                    stmt = stmt.limit(limit)
                results = self.session.execute(stmt).all()
                ids = [[row.id for row in results]]
                documents = [[row.text for row in results]]
                metadatas = [[row.vmetadata for row in results]]
            else:
                query = self.session.query(DocumentChunk).filter(
                    DocumentChunk.collection_name == collection_name
                )
                if limit is not None:
                    query = query.limit(limit)
                results = query.all()

                if not results:
                    return None

                ids = [[result.id for result in results]]
                documents = [[result.text for result in results]]
                metadatas = [[result.vmetadata for result in results]]

            self.session.rollback()  # read-only transaction
            return GetResult(ids=ids, documents=documents, metadatas=metadatas)
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during get: {e}")
            return None
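
    # delete() narrows by collection name, then optionally by explicit ids
    # and by metadata key/value pairs; all supplied conditions are ANDed.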

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ) -> None:
        try:
            if PGVECTOR_PGCRYPTO:
                wheres = [DocumentChunk.collection_name == collection_name]
                if ids:
                    wheres.append(DocumentChunk.id.in_(ids))
                if filter:
                    for key, value in filter.items():
                        wheres.append(
                            pgcrypto_decrypt(
                                DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                            )[key].astext
                            == str(value)
                        )
                stmt = DocumentChunk.__table__.delete().where(*wheres)
                result = self.session.execute(stmt)
                deleted = result.rowcount
            else:
                query = self.session.query(DocumentChunk).filter(
                    DocumentChunk.collection_name == collection_name
                )
                if ids:
                    query = query.filter(DocumentChunk.id.in_(ids))
                if filter:
                    for key, value in filter.items():
                        query = query.filter(
                            DocumentChunk.vmetadata[key].astext == str(value)
                        )
                deleted = query.delete(synchronize_session=False)
            self.session.commit()
            log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during delete: {e}")
            raise

    def reset(self) -> None:
        try:
            deleted = self.session.query(DocumentChunk).delete()
            self.session.commit()
            log.info(
                f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during reset: {e}")
            raise

    def close(self) -> None:
        pass

    def has_collection(self, collection_name: str) -> bool:
        try:
            exists = (
                self.session.query(DocumentChunk)
                .filter(DocumentChunk.collection_name == collection_name)
                .first()
                is not None
            )
            self.session.rollback()  # read-only transaction
            return exists
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error checking collection existence: {e}")
            return False

    def delete_collection(self, collection_name: str) -> None:
        self.delete(collection_name)
        log.info(f"Collection '{collection_name}' deleted.")
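

# --- Usage sketch ------------------------------------------------------------
# A minimal sketch, assuming PGVECTOR_DB_URL (or the app database) points at a
# PostgreSQL instance where the pgvector extension can be created. The
# collection name and the toy embedding below are hypothetical illustration
# values, not anything defined by this module; short vectors are padded with
# zeros up to VECTOR_LENGTH by adjust_vector_length().
if __name__ == "__main__":
    client = PgvectorClient()
    client.upsert(
        "demo-collection",  # hypothetical collection name
        [
            {
                "id": "chunk-1",
                "text": "hello world",
                "vector": [0.1, 0.2, 0.3],  # padded/truncated to VECTOR_LENGTH
                "metadata": {"source": "example.txt"},
            }
        ],
    )
    result = client.search("demo-collection", vectors=[[0.1, 0.2, 0.3]], limit=5)
    if result:
        print(result.ids, result.distances)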