# pgvector.py — pgvector-backed vector database client for Open WebUI.
  1. from typing import Optional, List, Dict, Any
  2. import logging
  3. import json
  4. from sqlalchemy import (
  5. func,
  6. literal,
  7. cast,
  8. column,
  9. create_engine,
  10. Column,
  11. Integer,
  12. MetaData,
  13. LargeBinary,
  14. select,
  15. text,
  16. Text,
  17. Table,
  18. values,
  19. )
  20. from sqlalchemy.sql import true
  21. from sqlalchemy.pool import NullPool, QueuePool
  22. from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
  23. from sqlalchemy.dialects.postgresql import JSONB, array
  24. from pgvector.sqlalchemy import Vector
  25. from sqlalchemy.ext.mutable import MutableDict
  26. from sqlalchemy.exc import NoSuchTableError
  27. from open_webui.retrieval.vector.utils import process_metadata
  28. from open_webui.retrieval.vector.main import (
  29. VectorDBBase,
  30. VectorItem,
  31. SearchResult,
  32. GetResult,
  33. )
  34. from open_webui.config import (
  35. PGVECTOR_DB_URL,
  36. PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
  37. PGVECTOR_CREATE_EXTENSION,
  38. PGVECTOR_PGCRYPTO,
  39. PGVECTOR_PGCRYPTO_KEY,
  40. PGVECTOR_POOL_SIZE,
  41. PGVECTOR_POOL_MAX_OVERFLOW,
  42. PGVECTOR_POOL_TIMEOUT,
  43. PGVECTOR_POOL_RECYCLE,
  44. )
  45. from open_webui.env import SRC_LOG_LEVELS
# Target dimensionality for every stored vector. Vectors of a different
# length are zero-padded or truncated to this size (see
# PgvectorClient.adjust_vector_length); it cannot change after the table
# is created without migrating the data (see check_vector_length).
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH

# Declarative base for the ORM model(s) defined in this module.
Base = declarative_base()

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
  50. def pgcrypto_encrypt(val, key):
  51. return func.pgp_sym_encrypt(val, literal(key))
  52. def pgcrypto_decrypt(col, key, outtype="text"):
  53. return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype)
class DocumentChunk(Base):
    """ORM model for one embedded document chunk.

    When PGVECTOR_PGCRYPTO is enabled at import time, ``text`` and
    ``vmetadata`` are raw BYTEA columns holding pgp_sym_encrypt output
    and are read/written via raw SQL elsewhere in this module; otherwise
    they are plain Text and mutable JSONB columns.
    """

    __tablename__ = "document_chunk"

    id = Column(Text, primary_key=True)
    # Embedding; dimension is fixed to VECTOR_LENGTH at table creation.
    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
    collection_name = Column(Text, nullable=False)
    if PGVECTOR_PGCRYPTO:
        # Encrypted-at-rest variants (BYTEA from pgp_sym_encrypt).
        text = Column(LargeBinary, nullable=True)
        vmetadata = Column(LargeBinary, nullable=True)
    else:
        text = Column(Text, nullable=True)
        # MutableDict so in-place metadata edits are tracked by the session.
        vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class PgvectorClient(VectorDBBase):
    """Vector-store backend on PostgreSQL + pgvector.

    Chunks are stored in the ``document_chunk`` table, grouped by
    ``collection_name``. Similarity search uses cosine distance with an
    ivfflat index. When PGVECTOR_PGCRYPTO is enabled, ``text`` and
    ``vmetadata`` are encrypted/decrypted in SQL with pgcrypto.
    """

    def __init__(self) -> None:
        """Create (or reuse) a DB session, then ensure extensions,
        schema, and indexes exist.

        Raises:
            ValueError: if PGVECTOR_PGCRYPTO is on but no key is set.
            Exception: on vector-length mismatch or any DB failure
                (transaction is rolled back before re-raising).
        """
        # if no pgvector uri, use the existing database connection
        if not PGVECTOR_DB_URL:
            from open_webui.internal.db import Session

            self.session = Session
        else:
            # Pool policy is driven by PGVECTOR_POOL_SIZE:
            #   int > 0  -> bounded QueuePool with the configured limits
            #   int <= 0 -> NullPool (no pooling)
            #   non-int  -> SQLAlchemy's default pooling
            if isinstance(PGVECTOR_POOL_SIZE, int):
                if PGVECTOR_POOL_SIZE > 0:
                    engine = create_engine(
                        PGVECTOR_DB_URL,
                        pool_size=PGVECTOR_POOL_SIZE,
                        max_overflow=PGVECTOR_POOL_MAX_OVERFLOW,
                        pool_timeout=PGVECTOR_POOL_TIMEOUT,
                        pool_recycle=PGVECTOR_POOL_RECYCLE,
                        pool_pre_ping=True,
                        poolclass=QueuePool,
                    )
                else:
                    engine = create_engine(
                        PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
                    )
            else:
                engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True)

            SessionLocal = sessionmaker(
                autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
            )
            # scoped_session gives one session per thread.
            self.session = scoped_session(SessionLocal)

        try:
            # Ensure the pgvector extension is available
            # Use a conditional check to avoid permission issues on Azure PostgreSQL
            if PGVECTOR_CREATE_EXTENSION:
                self.session.execute(
                    text(
                        """
                        DO $$
                        BEGIN
                            IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN
                                CREATE EXTENSION IF NOT EXISTS vector;
                            END IF;
                        END $$;
                        """
                    )
                )

            if PGVECTOR_PGCRYPTO:
                # Ensure the pgcrypto extension is available for encryption
                # Use a conditional check to avoid permission issues on Azure PostgreSQL
                self.session.execute(
                    text(
                        """
                        DO $$
                        BEGIN
                            IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN
                                CREATE EXTENSION IF NOT EXISTS pgcrypto;
                            END IF;
                        END $$;
                        """
                    )
                )
                if not PGVECTOR_PGCRYPTO_KEY:
                    raise ValueError(
                        "PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled."
                    )

            # Check vector length consistency
            self.check_vector_length()

            # Create the tables if they do not exist
            # Base.metadata.create_all requires a bind (engine or connection)
            # Get the connection from the session
            connection = self.session.connection()
            Base.metadata.create_all(bind=connection)

            # Create an index on the vector column if it doesn't exist
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
                    "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
                )
            )
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
                    "ON document_chunk (collection_name);"
                )
            )
            self.session.commit()
            log.info("Initialization complete.")
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during initialization: {e}")
            raise

    def check_vector_length(self) -> None:
        """
        Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
        Raises an exception if there is a mismatch.
        """
        metadata = MetaData()
        try:
            # Attempt to reflect the 'document_chunk' table
            document_chunk_table = Table(
                "document_chunk", metadata, autoload_with=self.session.bind
            )
        except NoSuchTableError:
            # Table does not exist; no action needed
            return

        # Proceed to check the vector column
        if "vector" in document_chunk_table.columns:
            vector_column = document_chunk_table.columns["vector"]
            vector_type = vector_column.type
            if isinstance(vector_type, Vector):
                db_vector_length = vector_type.dim
                if db_vector_length != VECTOR_LENGTH:
                    raise Exception(
                        f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
                        "Cannot change vector size after initialization without migrating the data."
                    )
            else:
                raise Exception(
                    "The 'vector' column exists but is not of type 'Vector'."
                )
        else:
            raise Exception(
                "The 'vector' column does not exist in the 'document_chunk' table."
            )

    def adjust_vector_length(self, vector: List[float]) -> List[float]:
        """Pad with zeros or truncate *vector* to exactly VECTOR_LENGTH.

        NOTE: pads the input list in place when it is too short (the
        caller's list is mutated via ``+=``).
        """
        # Adjust vector to have length VECTOR_LENGTH
        current_length = len(vector)
        if current_length < VECTOR_LENGTH:
            # Pad the vector with zeros
            vector += [0.0] * (VECTOR_LENGTH - current_length)
        elif current_length > VECTOR_LENGTH:
            # Truncate the vector to VECTOR_LENGTH
            vector = vector[:VECTOR_LENGTH]
        return vector

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert *items* into *collection_name*.

        Existing ids are silently skipped (plain path relies on bulk
        insert; encrypted path uses ON CONFLICT DO NOTHING). Commits on
        success, rolls back and re-raises on failure.
        """
        try:
            if PGVECTOR_PGCRYPTO:
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    # Use raw SQL for BYTEA/pgcrypto
                    # Ensure metadata is converted to its JSON text representation
                    json_metadata = json.dumps(item["metadata"])
                    self.session.execute(
                        text(
                            """
                            INSERT INTO document_chunk
                                (id, vector, collection_name, text, vmetadata)
                            VALUES (
                                :id, :vector, :collection_name,
                                pgp_sym_encrypt(:text, :key),
                                pgp_sym_encrypt(:metadata_text, :key)
                            )
                            ON CONFLICT (id) DO NOTHING
                            """
                        ),
                        {
                            "id": item["id"],
                            "vector": vector,
                            "collection_name": collection_name,
                            "text": item["text"],
                            "metadata_text": json_metadata,
                            "key": PGVECTOR_PGCRYPTO_KEY,
                        },
                    )
                self.session.commit()
                log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'")
            else:
                new_items = []
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    new_chunk = DocumentChunk(
                        id=item["id"],
                        vector=vector,
                        collection_name=collection_name,
                        text=item["text"],
                        vmetadata=process_metadata(item["metadata"]),
                    )
                    new_items.append(new_chunk)
                self.session.bulk_save_objects(new_items)
                self.session.commit()
                log.info(
                    f"Inserted {len(new_items)} items into collection '{collection_name}'."
                )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during insert: {e}")
            raise

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert or update *items* by id in *collection_name*.

        Encrypted path uses a single ON CONFLICT DO UPDATE statement per
        item; plain path does a query-then-update/insert per item.
        Commits on success, rolls back and re-raises on failure.
        """
        try:
            if PGVECTOR_PGCRYPTO:
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    json_metadata = json.dumps(item["metadata"])
                    self.session.execute(
                        text(
                            """
                            INSERT INTO document_chunk
                                (id, vector, collection_name, text, vmetadata)
                            VALUES (
                                :id, :vector, :collection_name,
                                pgp_sym_encrypt(:text, :key),
                                pgp_sym_encrypt(:metadata_text, :key)
                            )
                            ON CONFLICT (id) DO UPDATE SET
                                vector = EXCLUDED.vector,
                                collection_name = EXCLUDED.collection_name,
                                text = EXCLUDED.text,
                                vmetadata = EXCLUDED.vmetadata
                            """
                        ),
                        {
                            "id": item["id"],
                            "vector": vector,
                            "collection_name": collection_name,
                            "text": item["text"],
                            "metadata_text": json_metadata,
                            "key": PGVECTOR_PGCRYPTO_KEY,
                        },
                    )
                self.session.commit()
                log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
            else:
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    existing = (
                        self.session.query(DocumentChunk)
                        .filter(DocumentChunk.id == item["id"])
                        .first()
                    )
                    if existing:
                        existing.vector = vector
                        existing.text = item["text"]
                        existing.vmetadata = process_metadata(item["metadata"])
                        existing.collection_name = (
                            collection_name  # Update collection_name if necessary
                        )
                    else:
                        new_chunk = DocumentChunk(
                            id=item["id"],
                            vector=vector,
                            collection_name=collection_name,
                            text=item["text"],
                            vmetadata=process_metadata(item["metadata"]),
                        )
                        self.session.add(new_chunk)
                self.session.commit()
                log.info(
                    f"Upserted {len(items)} items into collection '{collection_name}'."
                )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during upsert: {e}")
            raise

    def search(
        self,
        collection_name: str,
        vectors: List[List[float]],
        limit: Optional[int] = None,
    ) -> Optional[SearchResult]:
        """Run a cosine-similarity search for each query vector.

        Builds a VALUES list of query vectors and LATERAL-joins the
        per-query top-``limit`` nearest chunks, so all queries run in a
        single SQL statement. Returns per-query parallel lists of ids,
        normalized similarity scores, documents, and metadata; returns
        None on empty input or error.
        """
        try:
            if not vectors:
                return None
            # Adjust query vectors to VECTOR_LENGTH
            vectors = [self.adjust_vector_length(vector) for vector in vectors]
            num_queries = len(vectors)

            def vector_expr(vector):
                # Cast a Python list to a pgvector value of the fixed dim.
                return cast(array(vector), Vector(VECTOR_LENGTH))

            # Create the values for query vectors
            qid_col = column("qid", Integer)
            q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
            query_vectors = (
                values(qid_col, q_vector_col)
                .data(
                    [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
                )
                .alias("query_vectors")
            )

            result_fields = [
                DocumentChunk.id,
            ]
            if PGVECTOR_PGCRYPTO:
                # Decrypt text/metadata in SQL so rows come back plaintext.
                result_fields.append(
                    pgcrypto_decrypt(
                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
                    ).label("text")
                )
                result_fields.append(
                    pgcrypto_decrypt(
                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                    ).label("vmetadata")
                )
            else:
                result_fields.append(DocumentChunk.text)
                result_fields.append(DocumentChunk.vmetadata)
            result_fields.append(
                (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
                    "distance"
                )
            )

            # Build the lateral subquery for each query vector
            subq = (
                select(*result_fields)
                .where(DocumentChunk.collection_name == collection_name)
                .order_by(
                    (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
                )
            )
            if limit is not None:
                subq = subq.limit(limit)
            subq = subq.lateral("result")

            # Build the main query by joining query_vectors and the lateral subquery
            stmt = (
                select(
                    query_vectors.c.qid,
                    subq.c.id,
                    subq.c.text,
                    subq.c.vmetadata,
                    subq.c.distance,
                )
                .select_from(query_vectors)
                .join(subq, true())
                .order_by(query_vectors.c.qid, subq.c.distance)
            )

            result_proxy = self.session.execute(stmt)
            results = result_proxy.all()

            ids = [[] for _ in range(num_queries)]
            distances = [[] for _ in range(num_queries)]
            documents = [[] for _ in range(num_queries)]
            metadatas = [[] for _ in range(num_queries)]

            if not results:
                return SearchResult(
                    ids=ids,
                    distances=distances,
                    documents=documents,
                    metadatas=metadatas,
                )

            for row in results:
                qid = int(row.qid)
                ids[qid].append(row.id)
                # normalize and re-orders pgvec distance from [2, 0] to [0, 1] score range
                # https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
                distances[qid].append((2.0 - row.distance) / 2.0)
                documents[qid].append(row.text)
                metadatas[qid].append(row.vmetadata)

            self.session.rollback()  # read-only transaction
            return SearchResult(
                ids=ids, distances=distances, documents=documents, metadatas=metadatas
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during search: {e}")
            return None

    def query(
        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Fetch chunks from *collection_name* whose metadata matches every
        key/value pair in *filter* (string-compared via ``astext``).

        Returns None when nothing matches or on error.
        """
        try:
            if PGVECTOR_PGCRYPTO:
                # Build where clause for vmetadata filter
                where_clauses = [DocumentChunk.collection_name == collection_name]
                for key, value in filter.items():
                    # decrypt then check key: JSON filter after decryption
                    where_clauses.append(
                        pgcrypto_decrypt(
                            DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                        )[key].astext
                        == str(value)
                    )
                stmt = select(
                    DocumentChunk.id,
                    pgcrypto_decrypt(
                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
                    ).label("text"),
                    pgcrypto_decrypt(
                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                    ).label("vmetadata"),
                ).where(*where_clauses)
                if limit is not None:
                    stmt = stmt.limit(limit)
                results = self.session.execute(stmt).all()
            else:
                query = self.session.query(DocumentChunk).filter(
                    DocumentChunk.collection_name == collection_name
                )
                for key, value in filter.items():
                    query = query.filter(
                        DocumentChunk.vmetadata[key].astext == str(value)
                    )
                if limit is not None:
                    query = query.limit(limit)
                results = query.all()

            if not results:
                return None

            ids = [[result.id for result in results]]
            documents = [[result.text for result in results]]
            metadatas = [[result.vmetadata for result in results]]
            self.session.rollback()  # read-only transaction
            return GetResult(
                ids=ids,
                documents=documents,
                metadatas=metadatas,
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during query: {e}")
            return None

    def get(
        self, collection_name: str, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Fetch up to *limit* chunks from *collection_name*.

        NOTE(review): on an empty collection the plain branch returns
        None while the pgcrypto branch returns a GetResult with empty
        lists — looks like an unintended inconsistency; confirm callers.
        """
        try:
            if PGVECTOR_PGCRYPTO:
                stmt = select(
                    DocumentChunk.id,
                    pgcrypto_decrypt(
                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
                    ).label("text"),
                    pgcrypto_decrypt(
                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                    ).label("vmetadata"),
                ).where(DocumentChunk.collection_name == collection_name)
                if limit is not None:
                    stmt = stmt.limit(limit)
                results = self.session.execute(stmt).all()
                ids = [[row.id for row in results]]
                documents = [[row.text for row in results]]
                metadatas = [[row.vmetadata for row in results]]
            else:
                query = self.session.query(DocumentChunk).filter(
                    DocumentChunk.collection_name == collection_name
                )
                if limit is not None:
                    query = query.limit(limit)
                results = query.all()

                if not results:
                    return None

                ids = [[result.id for result in results]]
                documents = [[result.text for result in results]]
                metadatas = [[result.vmetadata for result in results]]

            self.session.rollback()  # read-only transaction
            return GetResult(ids=ids, documents=documents, metadatas=metadatas)
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during get: {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Delete chunks from *collection_name*, optionally narrowed by
        explicit *ids* and/or a metadata *filter*.

        Commits on success; rolls back and re-raises on failure.
        """
        try:
            if PGVECTOR_PGCRYPTO:
                wheres = [DocumentChunk.collection_name == collection_name]
                if ids:
                    wheres.append(DocumentChunk.id.in_(ids))
                if filter:
                    for key, value in filter.items():
                        # Metadata must be decrypted in SQL before filtering.
                        wheres.append(
                            pgcrypto_decrypt(
                                DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                            )[key].astext
                            == str(value)
                        )
                stmt = DocumentChunk.__table__.delete().where(*wheres)
                result = self.session.execute(stmt)
                deleted = result.rowcount
            else:
                query = self.session.query(DocumentChunk).filter(
                    DocumentChunk.collection_name == collection_name
                )
                if ids:
                    query = query.filter(DocumentChunk.id.in_(ids))
                if filter:
                    for key, value in filter.items():
                        query = query.filter(
                            DocumentChunk.vmetadata[key].astext == str(value)
                        )
                deleted = query.delete(synchronize_session=False)
            self.session.commit()
            log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during delete: {e}")
            raise

    def reset(self) -> None:
        """Delete every row in document_chunk (all collections).

        Commits on success; rolls back and re-raises on failure.
        """
        try:
            deleted = self.session.query(DocumentChunk).delete()
            self.session.commit()
            log.info(
                f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during reset: {e}")
            raise

    def close(self) -> None:
        # No-op: session lifetime is managed by scoped_session / the
        # shared app session, not by this client.
        pass

    def has_collection(self, collection_name: str) -> bool:
        """Return True if at least one chunk exists for *collection_name*;
        False when empty or on error."""
        try:
            exists = (
                self.session.query(DocumentChunk)
                .filter(DocumentChunk.collection_name == collection_name)
                .first()
                is not None
            )
            self.session.rollback()  # read-only transaction
            return exists
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error checking collection existence: {e}")
            return False

    def delete_collection(self, collection_name: str) -> None:
        """Remove every chunk belonging to *collection_name*."""
        self.delete(collection_name)
        log.info(f"Collection '{collection_name}' deleted.")