pgvector.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626
  1. from typing import Optional, List, Dict, Any
  2. import logging
  3. import json
  4. from sqlalchemy import (
  5. func,
  6. literal,
  7. cast,
  8. column,
  9. create_engine,
  10. Column,
  11. Integer,
  12. MetaData,
  13. LargeBinary,
  14. select,
  15. text,
  16. Text,
  17. Table,
  18. values,
  19. )
  20. from sqlalchemy.sql import true
  21. from sqlalchemy.pool import NullPool, QueuePool
  22. from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
  23. from sqlalchemy.dialects.postgresql import JSONB, array
  24. from pgvector.sqlalchemy import Vector
  25. from sqlalchemy.ext.mutable import MutableDict
  26. from sqlalchemy.exc import NoSuchTableError
  27. from open_webui.retrieval.vector.utils import stringify_metadata
  28. from open_webui.retrieval.vector.main import (
  29. VectorDBBase,
  30. VectorItem,
  31. SearchResult,
  32. GetResult,
  33. )
  34. from open_webui.config import (
  35. PGVECTOR_DB_URL,
  36. PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
  37. PGVECTOR_PGCRYPTO,
  38. PGVECTOR_PGCRYPTO_KEY,
  39. PGVECTOR_POOL_SIZE,
  40. PGVECTOR_POOL_MAX_OVERFLOW,
  41. PGVECTOR_POOL_TIMEOUT,
  42. PGVECTOR_POOL_RECYCLE,
  43. )
  44. from open_webui.env import SRC_LOG_LEVELS
# Fixed dimension for all stored embeddings; vectors are padded/truncated
# to this length before insertion (see PgvectorClient.adjust_vector_length).
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH

# Declarative base for the ORM model(s) defined in this module.
Base = declarative_base()

# Module logger, scoped to the RAG log level from the app configuration.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
  49. def pgcrypto_encrypt(val, key):
  50. return func.pgp_sym_encrypt(val, literal(key))
  51. def pgcrypto_decrypt(col, key, outtype="text"):
  52. return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype)
class DocumentChunk(Base):
    # ORM model: one embedded chunk of a document, keyed by chunk id and
    # grouped into logical collections via collection_name.
    __tablename__ = "document_chunk"

    id = Column(Text, primary_key=True)
    # Embedding; dimension is fixed at import time by VECTOR_LENGTH.
    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
    collection_name = Column(Text, nullable=False)

    if PGVECTOR_PGCRYPTO:
        # Encrypted-at-rest mode: text and metadata are stored as raw
        # pgcrypto ciphertext (BYTEA); reads must decrypt via pgp_sym_decrypt.
        text = Column(LargeBinary, nullable=True)
        vmetadata = Column(LargeBinary, nullable=True)
    else:
        text = Column(Text, nullable=True)
        # MutableDict wrapper so in-place mutations of the JSONB metadata
        # are detected and flushed by the ORM.
        vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
  64. class PgvectorClient(VectorDBBase):
  65. def __init__(self) -> None:
  66. # if no pgvector uri, use the existing database connection
  67. if not PGVECTOR_DB_URL:
  68. from open_webui.internal.db import Session
  69. self.session = Session
  70. else:
  71. if isinstance(PGVECTOR_POOL_SIZE, int):
  72. if PGVECTOR_POOL_SIZE > 0:
  73. engine = create_engine(
  74. PGVECTOR_DB_URL,
  75. pool_size=PGVECTOR_POOL_SIZE,
  76. max_overflow=PGVECTOR_POOL_MAX_OVERFLOW,
  77. pool_timeout=PGVECTOR_POOL_TIMEOUT,
  78. pool_recycle=PGVECTOR_POOL_RECYCLE,
  79. pool_pre_ping=True,
  80. poolclass=QueuePool,
  81. )
  82. else:
  83. engine = create_engine(
  84. PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
  85. )
  86. else:
  87. engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True)
  88. SessionLocal = sessionmaker(
  89. autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
  90. )
  91. self.session = scoped_session(SessionLocal)
  92. try:
  93. # Ensure the pgvector extension is available
  94. # Use a conditional check to avoid permission issues on Azure PostgreSQL
  95. self.session.execute(text("""
  96. DO $$
  97. BEGIN
  98. IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN
  99. CREATE EXTENSION IF NOT EXISTS vector;
  100. END IF;
  101. END $$;
  102. """))
  103. if PGVECTOR_PGCRYPTO:
  104. # Ensure the pgcrypto extension is available for encryption
  105. # Use a conditional check to avoid permission issues on Azure PostgreSQL
  106. self.session.execute(text("""
  107. DO $$
  108. BEGIN
  109. IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN
  110. CREATE EXTENSION IF NOT EXISTS pgcrypto;
  111. END IF;
  112. END $$;
  113. """))
  114. if not PGVECTOR_PGCRYPTO_KEY:
  115. raise ValueError(
  116. "PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled."
  117. )
  118. # Check vector length consistency
  119. self.check_vector_length()
  120. # Create the tables if they do not exist
  121. # Base.metadata.create_all requires a bind (engine or connection)
  122. # Get the connection from the session
  123. connection = self.session.connection()
  124. Base.metadata.create_all(bind=connection)
  125. # Create an index on the vector column if it doesn't exist
  126. self.session.execute(
  127. text(
  128. "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
  129. "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
  130. )
  131. )
  132. self.session.execute(
  133. text(
  134. "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
  135. "ON document_chunk (collection_name);"
  136. )
  137. )
  138. self.session.commit()
  139. log.info("Initialization complete.")
  140. except Exception as e:
  141. self.session.rollback()
  142. log.exception(f"Error during initialization: {e}")
  143. raise
  144. def check_vector_length(self) -> None:
  145. """
  146. Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
  147. Raises an exception if there is a mismatch.
  148. """
  149. metadata = MetaData()
  150. try:
  151. # Attempt to reflect the 'document_chunk' table
  152. document_chunk_table = Table(
  153. "document_chunk", metadata, autoload_with=self.session.bind
  154. )
  155. except NoSuchTableError:
  156. # Table does not exist; no action needed
  157. return
  158. # Proceed to check the vector column
  159. if "vector" in document_chunk_table.columns:
  160. vector_column = document_chunk_table.columns["vector"]
  161. vector_type = vector_column.type
  162. if isinstance(vector_type, Vector):
  163. db_vector_length = vector_type.dim
  164. if db_vector_length != VECTOR_LENGTH:
  165. raise Exception(
  166. f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
  167. "Cannot change vector size after initialization without migrating the data."
  168. )
  169. else:
  170. raise Exception(
  171. "The 'vector' column exists but is not of type 'Vector'."
  172. )
  173. else:
  174. raise Exception(
  175. "The 'vector' column does not exist in the 'document_chunk' table."
  176. )
  177. def adjust_vector_length(self, vector: List[float]) -> List[float]:
  178. # Adjust vector to have length VECTOR_LENGTH
  179. current_length = len(vector)
  180. if current_length < VECTOR_LENGTH:
  181. # Pad the vector with zeros
  182. vector += [0.0] * (VECTOR_LENGTH - current_length)
  183. elif current_length > VECTOR_LENGTH:
  184. # Truncate the vector to VECTOR_LENGTH
  185. vector = vector[:VECTOR_LENGTH]
  186. return vector
  187. def insert(self, collection_name: str, items: List[VectorItem]) -> None:
  188. try:
  189. if PGVECTOR_PGCRYPTO:
  190. for item in items:
  191. vector = self.adjust_vector_length(item["vector"])
  192. # Use raw SQL for BYTEA/pgcrypto
  193. # Ensure metadata is converted to its JSON text representation
  194. json_metadata = json.dumps(item["metadata"])
  195. self.session.execute(
  196. text(
  197. """
  198. INSERT INTO document_chunk
  199. (id, vector, collection_name, text, vmetadata)
  200. VALUES (
  201. :id, :vector, :collection_name,
  202. pgp_sym_encrypt(:text, :key),
  203. pgp_sym_encrypt(:metadata_text, :key)
  204. )
  205. ON CONFLICT (id) DO NOTHING
  206. """
  207. ),
  208. {
  209. "id": item["id"],
  210. "vector": vector,
  211. "collection_name": collection_name,
  212. "text": item["text"],
  213. "metadata_text": json_metadata,
  214. "key": PGVECTOR_PGCRYPTO_KEY,
  215. },
  216. )
  217. self.session.commit()
  218. log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'")
  219. else:
  220. new_items = []
  221. for item in items:
  222. vector = self.adjust_vector_length(item["vector"])
  223. new_chunk = DocumentChunk(
  224. id=item["id"],
  225. vector=vector,
  226. collection_name=collection_name,
  227. text=item["text"],
  228. vmetadata=stringify_metadata(item["metadata"]),
  229. )
  230. new_items.append(new_chunk)
  231. self.session.bulk_save_objects(new_items)
  232. self.session.commit()
  233. log.info(
  234. f"Inserted {len(new_items)} items into collection '{collection_name}'."
  235. )
  236. except Exception as e:
  237. self.session.rollback()
  238. log.exception(f"Error during insert: {e}")
  239. raise
  240. def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
  241. try:
  242. if PGVECTOR_PGCRYPTO:
  243. for item in items:
  244. vector = self.adjust_vector_length(item["vector"])
  245. json_metadata = json.dumps(item["metadata"])
  246. self.session.execute(
  247. text(
  248. """
  249. INSERT INTO document_chunk
  250. (id, vector, collection_name, text, vmetadata)
  251. VALUES (
  252. :id, :vector, :collection_name,
  253. pgp_sym_encrypt(:text, :key),
  254. pgp_sym_encrypt(:metadata_text, :key)
  255. )
  256. ON CONFLICT (id) DO UPDATE SET
  257. vector = EXCLUDED.vector,
  258. collection_name = EXCLUDED.collection_name,
  259. text = EXCLUDED.text,
  260. vmetadata = EXCLUDED.vmetadata
  261. """
  262. ),
  263. {
  264. "id": item["id"],
  265. "vector": vector,
  266. "collection_name": collection_name,
  267. "text": item["text"],
  268. "metadata_text": json_metadata,
  269. "key": PGVECTOR_PGCRYPTO_KEY,
  270. },
  271. )
  272. self.session.commit()
  273. log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
  274. else:
  275. for item in items:
  276. vector = self.adjust_vector_length(item["vector"])
  277. existing = (
  278. self.session.query(DocumentChunk)
  279. .filter(DocumentChunk.id == item["id"])
  280. .first()
  281. )
  282. if existing:
  283. existing.vector = vector
  284. existing.text = item["text"]
  285. existing.vmetadata = stringify_metadata(item["metadata"])
  286. existing.collection_name = (
  287. collection_name # Update collection_name if necessary
  288. )
  289. else:
  290. new_chunk = DocumentChunk(
  291. id=item["id"],
  292. vector=vector,
  293. collection_name=collection_name,
  294. text=item["text"],
  295. vmetadata=stringify_metadata(item["metadata"]),
  296. )
  297. self.session.add(new_chunk)
  298. self.session.commit()
  299. log.info(
  300. f"Upserted {len(items)} items into collection '{collection_name}'."
  301. )
  302. except Exception as e:
  303. self.session.rollback()
  304. log.exception(f"Error during upsert: {e}")
  305. raise
  306. def search(
  307. self,
  308. collection_name: str,
  309. vectors: List[List[float]],
  310. limit: Optional[int] = None,
  311. ) -> Optional[SearchResult]:
  312. try:
  313. if not vectors:
  314. return None
  315. # Adjust query vectors to VECTOR_LENGTH
  316. vectors = [self.adjust_vector_length(vector) for vector in vectors]
  317. num_queries = len(vectors)
  318. def vector_expr(vector):
  319. return cast(array(vector), Vector(VECTOR_LENGTH))
  320. # Create the values for query vectors
  321. qid_col = column("qid", Integer)
  322. q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
  323. query_vectors = (
  324. values(qid_col, q_vector_col)
  325. .data(
  326. [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
  327. )
  328. .alias("query_vectors")
  329. )
  330. result_fields = [
  331. DocumentChunk.id,
  332. ]
  333. if PGVECTOR_PGCRYPTO:
  334. result_fields.append(
  335. pgcrypto_decrypt(
  336. DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
  337. ).label("text")
  338. )
  339. result_fields.append(
  340. pgcrypto_decrypt(
  341. DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
  342. ).label("vmetadata")
  343. )
  344. else:
  345. result_fields.append(DocumentChunk.text)
  346. result_fields.append(DocumentChunk.vmetadata)
  347. result_fields.append(
  348. (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
  349. "distance"
  350. )
  351. )
  352. # Build the lateral subquery for each query vector
  353. subq = (
  354. select(*result_fields)
  355. .where(DocumentChunk.collection_name == collection_name)
  356. .order_by(
  357. (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
  358. )
  359. )
  360. if limit is not None:
  361. subq = subq.limit(limit)
  362. subq = subq.lateral("result")
  363. # Build the main query by joining query_vectors and the lateral subquery
  364. stmt = (
  365. select(
  366. query_vectors.c.qid,
  367. subq.c.id,
  368. subq.c.text,
  369. subq.c.vmetadata,
  370. subq.c.distance,
  371. )
  372. .select_from(query_vectors)
  373. .join(subq, true())
  374. .order_by(query_vectors.c.qid, subq.c.distance)
  375. )
  376. result_proxy = self.session.execute(stmt)
  377. results = result_proxy.all()
  378. ids = [[] for _ in range(num_queries)]
  379. distances = [[] for _ in range(num_queries)]
  380. documents = [[] for _ in range(num_queries)]
  381. metadatas = [[] for _ in range(num_queries)]
  382. if not results:
  383. return SearchResult(
  384. ids=ids,
  385. distances=distances,
  386. documents=documents,
  387. metadatas=metadatas,
  388. )
  389. for row in results:
  390. qid = int(row.qid)
  391. ids[qid].append(row.id)
  392. # normalize and re-orders pgvec distance from [2, 0] to [0, 1] score range
  393. # https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
  394. distances[qid].append((2.0 - row.distance) / 2.0)
  395. documents[qid].append(row.text)
  396. metadatas[qid].append(row.vmetadata)
  397. self.session.rollback() # read-only transaction
  398. return SearchResult(
  399. ids=ids, distances=distances, documents=documents, metadatas=metadatas
  400. )
  401. except Exception as e:
  402. self.session.rollback()
  403. log.exception(f"Error during search: {e}")
  404. return None
  405. def query(
  406. self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
  407. ) -> Optional[GetResult]:
  408. try:
  409. if PGVECTOR_PGCRYPTO:
  410. # Build where clause for vmetadata filter
  411. where_clauses = [DocumentChunk.collection_name == collection_name]
  412. for key, value in filter.items():
  413. # decrypt then check key: JSON filter after decryption
  414. where_clauses.append(
  415. pgcrypto_decrypt(
  416. DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
  417. )[key].astext
  418. == str(value)
  419. )
  420. stmt = select(
  421. DocumentChunk.id,
  422. pgcrypto_decrypt(
  423. DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
  424. ).label("text"),
  425. pgcrypto_decrypt(
  426. DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
  427. ).label("vmetadata"),
  428. ).where(*where_clauses)
  429. if limit is not None:
  430. stmt = stmt.limit(limit)
  431. results = self.session.execute(stmt).all()
  432. else:
  433. query = self.session.query(DocumentChunk).filter(
  434. DocumentChunk.collection_name == collection_name
  435. )
  436. for key, value in filter.items():
  437. query = query.filter(
  438. DocumentChunk.vmetadata[key].astext == str(value)
  439. )
  440. if limit is not None:
  441. query = query.limit(limit)
  442. results = query.all()
  443. if not results:
  444. return None
  445. ids = [[result.id for result in results]]
  446. documents = [[result.text for result in results]]
  447. metadatas = [[result.vmetadata for result in results]]
  448. self.session.rollback() # read-only transaction
  449. return GetResult(
  450. ids=ids,
  451. documents=documents,
  452. metadatas=metadatas,
  453. )
  454. except Exception as e:
  455. self.session.rollback()
  456. log.exception(f"Error during query: {e}")
  457. return None
  458. def get(
  459. self, collection_name: str, limit: Optional[int] = None
  460. ) -> Optional[GetResult]:
  461. try:
  462. if PGVECTOR_PGCRYPTO:
  463. stmt = select(
  464. DocumentChunk.id,
  465. pgcrypto_decrypt(
  466. DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
  467. ).label("text"),
  468. pgcrypto_decrypt(
  469. DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
  470. ).label("vmetadata"),
  471. ).where(DocumentChunk.collection_name == collection_name)
  472. if limit is not None:
  473. stmt = stmt.limit(limit)
  474. results = self.session.execute(stmt).all()
  475. ids = [[row.id for row in results]]
  476. documents = [[row.text for row in results]]
  477. metadatas = [[row.vmetadata for row in results]]
  478. else:
  479. query = self.session.query(DocumentChunk).filter(
  480. DocumentChunk.collection_name == collection_name
  481. )
  482. if limit is not None:
  483. query = query.limit(limit)
  484. results = query.all()
  485. if not results:
  486. return None
  487. ids = [[result.id for result in results]]
  488. documents = [[result.text for result in results]]
  489. metadatas = [[result.vmetadata for result in results]]
  490. self.session.rollback() # read-only transaction
  491. return GetResult(ids=ids, documents=documents, metadatas=metadatas)
  492. except Exception as e:
  493. self.session.rollback()
  494. log.exception(f"Error during get: {e}")
  495. return None
  496. def delete(
  497. self,
  498. collection_name: str,
  499. ids: Optional[List[str]] = None,
  500. filter: Optional[Dict[str, Any]] = None,
  501. ) -> None:
  502. try:
  503. if PGVECTOR_PGCRYPTO:
  504. wheres = [DocumentChunk.collection_name == collection_name]
  505. if ids:
  506. wheres.append(DocumentChunk.id.in_(ids))
  507. if filter:
  508. for key, value in filter.items():
  509. wheres.append(
  510. pgcrypto_decrypt(
  511. DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
  512. )[key].astext
  513. == str(value)
  514. )
  515. stmt = DocumentChunk.__table__.delete().where(*wheres)
  516. result = self.session.execute(stmt)
  517. deleted = result.rowcount
  518. else:
  519. query = self.session.query(DocumentChunk).filter(
  520. DocumentChunk.collection_name == collection_name
  521. )
  522. if ids:
  523. query = query.filter(DocumentChunk.id.in_(ids))
  524. if filter:
  525. for key, value in filter.items():
  526. query = query.filter(
  527. DocumentChunk.vmetadata[key].astext == str(value)
  528. )
  529. deleted = query.delete(synchronize_session=False)
  530. self.session.commit()
  531. log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
  532. except Exception as e:
  533. self.session.rollback()
  534. log.exception(f"Error during delete: {e}")
  535. raise
  536. def reset(self) -> None:
  537. try:
  538. deleted = self.session.query(DocumentChunk).delete()
  539. self.session.commit()
  540. log.info(
  541. f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
  542. )
  543. except Exception as e:
  544. self.session.rollback()
  545. log.exception(f"Error during reset: {e}")
  546. raise
  547. def close(self) -> None:
  548. pass
  549. def has_collection(self, collection_name: str) -> bool:
  550. try:
  551. exists = (
  552. self.session.query(DocumentChunk)
  553. .filter(DocumentChunk.collection_name == collection_name)
  554. .first()
  555. is not None
  556. )
  557. self.session.rollback() # read-only transaction
  558. return exists
  559. except Exception as e:
  560. self.session.rollback()
  561. log.exception(f"Error checking collection existence: {e}")
  562. return False
  563. def delete_collection(self, collection_name: str) -> None:
  564. self.delete(collection_name)
  565. log.info(f"Collection '{collection_name}' deleted.")