# pgvector.py — pgvector-backed vector database client for Open WebUI.
  1. from typing import Optional, List, Dict, Any
  2. import logging
  3. import json
  4. from sqlalchemy import (
  5. func,
  6. literal,
  7. cast,
  8. column,
  9. create_engine,
  10. Column,
  11. Integer,
  12. MetaData,
  13. LargeBinary,
  14. select,
  15. text,
  16. Text,
  17. Table,
  18. values,
  19. )
  20. from sqlalchemy.sql import true
  21. from sqlalchemy.pool import NullPool, QueuePool
  22. from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
  23. from sqlalchemy.dialects.postgresql import JSONB, array
  24. from pgvector.sqlalchemy import Vector
  25. from sqlalchemy.ext.mutable import MutableDict
  26. from sqlalchemy.exc import NoSuchTableError
  27. from open_webui.retrieval.vector.utils import stringify_metadata
  28. from open_webui.retrieval.vector.main import (
  29. VectorDBBase,
  30. VectorItem,
  31. SearchResult,
  32. GetResult,
  33. )
  34. from open_webui.config import (
  35. PGVECTOR_DB_URL,
  36. PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
  37. PGVECTOR_PGCRYPTO,
  38. PGVECTOR_PGCRYPTO_KEY,
  39. PGVECTOR_POOL_SIZE,
  40. PGVECTOR_POOL_MAX_OVERFLOW,
  41. PGVECTOR_POOL_TIMEOUT,
  42. PGVECTOR_POOL_RECYCLE,
  43. )
  44. from open_webui.env import SRC_LOG_LEVELS
# Target dimensionality for every stored vector; inputs of a different
# length are padded/truncated by PgvectorClient.adjust_vector_length().
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
# Declarative base for the document_chunk ORM model below.
Base = declarative_base()

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
  49. def pgcrypto_encrypt(val, key):
  50. return func.pgp_sym_encrypt(val, literal(key))
  51. def pgcrypto_decrypt(col, key, outtype="text"):
  52. return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype)
class DocumentChunk(Base):
    # One row per embedded document chunk. When PGVECTOR_PGCRYPTO is enabled
    # at import time, text/vmetadata are BYTEA columns holding
    # pgp_sym_encrypt output (encryption/decryption happens in SQL via the
    # pgcrypto_encrypt/pgcrypto_decrypt helpers); otherwise they are plain
    # Text/JSONB.
    __tablename__ = "document_chunk"

    id = Column(Text, primary_key=True)
    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
    collection_name = Column(Text, nullable=False)
    if PGVECTOR_PGCRYPTO:
        # Encrypted payloads (raw pgp_sym_encrypt bytes).
        text = Column(LargeBinary, nullable=True)
        vmetadata = Column(LargeBinary, nullable=True)
    else:
        text = Column(Text, nullable=True)
        vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class PgvectorClient(VectorDBBase):
    """Vector database backend storing document chunks in Postgres via pgvector.

    Optionally encrypts chunk text/metadata at rest with pgcrypto when
    PGVECTOR_PGCRYPTO is enabled.
    """

    def __init__(self) -> None:
        # if no pgvector uri, use the existing database connection
        if not PGVECTOR_DB_URL:
            from open_webui.internal.db import Session

            self.session = Session
        else:
            if isinstance(PGVECTOR_POOL_SIZE, int):
                if PGVECTOR_POOL_SIZE > 0:
                    # Explicit pooling with the configured size/overflow/timeouts.
                    engine = create_engine(
                        PGVECTOR_DB_URL,
                        pool_size=PGVECTOR_POOL_SIZE,
                        max_overflow=PGVECTOR_POOL_MAX_OVERFLOW,
                        pool_timeout=PGVECTOR_POOL_TIMEOUT,
                        pool_recycle=PGVECTOR_POOL_RECYCLE,
                        pool_pre_ping=True,
                        poolclass=QueuePool,
                    )
                else:
                    # Non-positive pool size disables pooling entirely.
                    engine = create_engine(
                        PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
                    )
            else:
                # Pool size not configured as an int: use SQLAlchemy defaults.
                engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True)

            SessionLocal = sessionmaker(
                autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
            )
            self.session = scoped_session(SessionLocal)

        try:
            # Ensure the pgvector extension is available
            self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))

            if PGVECTOR_PGCRYPTO:
                # Ensure the pgcrypto extension is available for encryption
                self.session.execute(text("CREATE EXTENSION IF NOT EXISTS pgcrypto;"))

                if not PGVECTOR_PGCRYPTO_KEY:
                    raise ValueError(
                        "PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled."
                    )

            # Check vector length consistency
            self.check_vector_length()

            # Create the tables if they do not exist
            # Base.metadata.create_all requires a bind (engine or connection)
            # Get the connection from the session
            connection = self.session.connection()
            Base.metadata.create_all(bind=connection)

            # Create an index on the vector column if it doesn't exist
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
                    "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
                )
            )
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
                    "ON document_chunk (collection_name);"
                )
            )

            self.session.commit()
            log.info("Initialization complete.")
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during initialization: {e}")
            raise
  128. def check_vector_length(self) -> None:
  129. """
  130. Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
  131. Raises an exception if there is a mismatch.
  132. """
  133. metadata = MetaData()
  134. try:
  135. # Attempt to reflect the 'document_chunk' table
  136. document_chunk_table = Table(
  137. "document_chunk", metadata, autoload_with=self.session.bind
  138. )
  139. except NoSuchTableError:
  140. # Table does not exist; no action needed
  141. return
  142. # Proceed to check the vector column
  143. if "vector" in document_chunk_table.columns:
  144. vector_column = document_chunk_table.columns["vector"]
  145. vector_type = vector_column.type
  146. if isinstance(vector_type, Vector):
  147. db_vector_length = vector_type.dim
  148. if db_vector_length != VECTOR_LENGTH:
  149. raise Exception(
  150. f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
  151. "Cannot change vector size after initialization without migrating the data."
  152. )
  153. else:
  154. raise Exception(
  155. "The 'vector' column exists but is not of type 'Vector'."
  156. )
  157. else:
  158. raise Exception(
  159. "The 'vector' column does not exist in the 'document_chunk' table."
  160. )
  161. def adjust_vector_length(self, vector: List[float]) -> List[float]:
  162. # Adjust vector to have length VECTOR_LENGTH
  163. current_length = len(vector)
  164. if current_length < VECTOR_LENGTH:
  165. # Pad the vector with zeros
  166. vector += [0.0] * (VECTOR_LENGTH - current_length)
  167. elif current_length > VECTOR_LENGTH:
  168. # Truncate the vector to VECTOR_LENGTH
  169. vector = vector[:VECTOR_LENGTH]
  170. return vector
  171. def insert(self, collection_name: str, items: List[VectorItem]) -> None:
  172. try:
  173. if PGVECTOR_PGCRYPTO:
  174. for item in items:
  175. vector = self.adjust_vector_length(item["vector"])
  176. # Use raw SQL for BYTEA/pgcrypto
  177. # Ensure metadata is converted to its JSON text representation
  178. json_metadata = json.dumps(item["metadata"])
  179. self.session.execute(
  180. text(
  181. """
  182. INSERT INTO document_chunk
  183. (id, vector, collection_name, text, vmetadata)
  184. VALUES (
  185. :id, :vector, :collection_name,
  186. pgp_sym_encrypt(:text, :key),
  187. pgp_sym_encrypt(:metadata_text, :key)
  188. )
  189. ON CONFLICT (id) DO NOTHING
  190. """
  191. ),
  192. {
  193. "id": item["id"],
  194. "vector": vector,
  195. "collection_name": collection_name,
  196. "text": item["text"],
  197. "metadata_text": json_metadata,
  198. "key": PGVECTOR_PGCRYPTO_KEY,
  199. },
  200. )
  201. self.session.commit()
  202. log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'")
  203. else:
  204. new_items = []
  205. for item in items:
  206. vector = self.adjust_vector_length(item["vector"])
  207. new_chunk = DocumentChunk(
  208. id=item["id"],
  209. vector=vector,
  210. collection_name=collection_name,
  211. text=item["text"],
  212. vmetadata=stringify_metadata(item["metadata"]),
  213. )
  214. new_items.append(new_chunk)
  215. self.session.bulk_save_objects(new_items)
  216. self.session.commit()
  217. log.info(
  218. f"Inserted {len(new_items)} items into collection '{collection_name}'."
  219. )
  220. except Exception as e:
  221. self.session.rollback()
  222. log.exception(f"Error during insert: {e}")
  223. raise
  224. def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
  225. try:
  226. if PGVECTOR_PGCRYPTO:
  227. for item in items:
  228. vector = self.adjust_vector_length(item["vector"])
  229. json_metadata = json.dumps(item["metadata"])
  230. self.session.execute(
  231. text(
  232. """
  233. INSERT INTO document_chunk
  234. (id, vector, collection_name, text, vmetadata)
  235. VALUES (
  236. :id, :vector, :collection_name,
  237. pgp_sym_encrypt(:text, :key),
  238. pgp_sym_encrypt(:metadata_text, :key)
  239. )
  240. ON CONFLICT (id) DO UPDATE SET
  241. vector = EXCLUDED.vector,
  242. collection_name = EXCLUDED.collection_name,
  243. text = EXCLUDED.text,
  244. vmetadata = EXCLUDED.vmetadata
  245. """
  246. ),
  247. {
  248. "id": item["id"],
  249. "vector": vector,
  250. "collection_name": collection_name,
  251. "text": item["text"],
  252. "metadata_text": json_metadata,
  253. "key": PGVECTOR_PGCRYPTO_KEY,
  254. },
  255. )
  256. self.session.commit()
  257. log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
  258. else:
  259. for item in items:
  260. vector = self.adjust_vector_length(item["vector"])
  261. existing = (
  262. self.session.query(DocumentChunk)
  263. .filter(DocumentChunk.id == item["id"])
  264. .first()
  265. )
  266. if existing:
  267. existing.vector = vector
  268. existing.text = item["text"]
  269. existing.vmetadata = stringify_metadata(item["metadata"])
  270. existing.collection_name = (
  271. collection_name # Update collection_name if necessary
  272. )
  273. else:
  274. new_chunk = DocumentChunk(
  275. id=item["id"],
  276. vector=vector,
  277. collection_name=collection_name,
  278. text=item["text"],
  279. vmetadata=stringify_metadata(item["metadata"]),
  280. )
  281. self.session.add(new_chunk)
  282. self.session.commit()
  283. log.info(
  284. f"Upserted {len(items)} items into collection '{collection_name}'."
  285. )
  286. except Exception as e:
  287. self.session.rollback()
  288. log.exception(f"Error during upsert: {e}")
  289. raise
    def search(
        self,
        collection_name: str,
        vectors: List[List[float]],
        limit: Optional[int] = None,
    ) -> Optional[SearchResult]:
        """
        Batched cosine-similarity search over *collection_name*.

        Builds a VALUES list of the query vectors and LATERAL-joins it to a
        per-vector top-*limit* subquery, so all queries run in one SQL
        statement. Returns per-query lists of ids/distances/documents/
        metadatas (distances rescaled to a [0, 1] score), or None on error
        or empty *vectors*.
        """
        try:
            if not vectors:
                return None
            # Adjust query vectors to VECTOR_LENGTH
            vectors = [self.adjust_vector_length(vector) for vector in vectors]
            num_queries = len(vectors)

            def vector_expr(vector):
                # Literal pgvector value: ARRAY[...]::vector(VECTOR_LENGTH)
                return cast(array(vector), Vector(VECTOR_LENGTH))

            # Create the values for query vectors
            qid_col = column("qid", Integer)
            q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
            query_vectors = (
                values(qid_col, q_vector_col)
                .data(
                    [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
                )
                .alias("query_vectors")
            )

            # Columns returned by the lateral subquery; text/vmetadata are
            # decrypted in SQL when pgcrypto storage is enabled.
            result_fields = [
                DocumentChunk.id,
            ]
            if PGVECTOR_PGCRYPTO:
                result_fields.append(
                    pgcrypto_decrypt(
                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
                    ).label("text")
                )
                result_fields.append(
                    pgcrypto_decrypt(
                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                    ).label("vmetadata")
                )
            else:
                result_fields.append(DocumentChunk.text)
                result_fields.append(DocumentChunk.vmetadata)
            result_fields.append(
                (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
                    "distance"
                )
            )

            # Build the lateral subquery for each query vector
            subq = (
                select(*result_fields)
                .where(DocumentChunk.collection_name == collection_name)
                .order_by(
                    (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
                )
            )
            if limit is not None:
                subq = subq.limit(limit)
            subq = subq.lateral("result")

            # Build the main query by joining query_vectors and the lateral subquery
            stmt = (
                select(
                    query_vectors.c.qid,
                    subq.c.id,
                    subq.c.text,
                    subq.c.vmetadata,
                    subq.c.distance,
                )
                .select_from(query_vectors)
                .join(subq, true())
                .order_by(query_vectors.c.qid, subq.c.distance)
            )

            result_proxy = self.session.execute(stmt)
            results = result_proxy.all()

            # One bucket per query vector, indexed by qid.
            ids = [[] for _ in range(num_queries)]
            distances = [[] for _ in range(num_queries)]
            documents = [[] for _ in range(num_queries)]
            metadatas = [[] for _ in range(num_queries)]

            if not results:
                return SearchResult(
                    ids=ids,
                    distances=distances,
                    documents=documents,
                    metadatas=metadatas,
                )

            for row in results:
                qid = int(row.qid)
                ids[qid].append(row.id)
                # normalize and re-orders pgvec distance from [2, 0] to [0, 1] score range
                # https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
                distances[qid].append((2.0 - row.distance) / 2.0)
                documents[qid].append(row.text)
                metadatas[qid].append(row.vmetadata)

            self.session.rollback()  # read-only transaction
            return SearchResult(
                ids=ids, distances=distances, documents=documents, metadatas=metadatas
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during search: {e}")
            return None
  389. def query(
  390. self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
  391. ) -> Optional[GetResult]:
  392. try:
  393. if PGVECTOR_PGCRYPTO:
  394. # Build where clause for vmetadata filter
  395. where_clauses = [DocumentChunk.collection_name == collection_name]
  396. for key, value in filter.items():
  397. # decrypt then check key: JSON filter after decryption
  398. where_clauses.append(
  399. pgcrypto_decrypt(
  400. DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
  401. )[key].astext
  402. == str(value)
  403. )
  404. stmt = select(
  405. DocumentChunk.id,
  406. pgcrypto_decrypt(
  407. DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
  408. ).label("text"),
  409. pgcrypto_decrypt(
  410. DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
  411. ).label("vmetadata"),
  412. ).where(*where_clauses)
  413. if limit is not None:
  414. stmt = stmt.limit(limit)
  415. results = self.session.execute(stmt).all()
  416. else:
  417. query = self.session.query(DocumentChunk).filter(
  418. DocumentChunk.collection_name == collection_name
  419. )
  420. for key, value in filter.items():
  421. query = query.filter(
  422. DocumentChunk.vmetadata[key].astext == str(value)
  423. )
  424. if limit is not None:
  425. query = query.limit(limit)
  426. results = query.all()
  427. if not results:
  428. return None
  429. ids = [[result.id for result in results]]
  430. documents = [[result.text for result in results]]
  431. metadatas = [[result.vmetadata for result in results]]
  432. self.session.rollback() # read-only transaction
  433. return GetResult(
  434. ids=ids,
  435. documents=documents,
  436. metadatas=metadatas,
  437. )
  438. except Exception as e:
  439. self.session.rollback()
  440. log.exception(f"Error during query: {e}")
  441. return None
  442. def get(
  443. self, collection_name: str, limit: Optional[int] = None
  444. ) -> Optional[GetResult]:
  445. try:
  446. if PGVECTOR_PGCRYPTO:
  447. stmt = select(
  448. DocumentChunk.id,
  449. pgcrypto_decrypt(
  450. DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
  451. ).label("text"),
  452. pgcrypto_decrypt(
  453. DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
  454. ).label("vmetadata"),
  455. ).where(DocumentChunk.collection_name == collection_name)
  456. if limit is not None:
  457. stmt = stmt.limit(limit)
  458. results = self.session.execute(stmt).all()
  459. ids = [[row.id for row in results]]
  460. documents = [[row.text for row in results]]
  461. metadatas = [[row.vmetadata for row in results]]
  462. else:
  463. query = self.session.query(DocumentChunk).filter(
  464. DocumentChunk.collection_name == collection_name
  465. )
  466. if limit is not None:
  467. query = query.limit(limit)
  468. results = query.all()
  469. if not results:
  470. return None
  471. ids = [[result.id for result in results]]
  472. documents = [[result.text for result in results]]
  473. metadatas = [[result.vmetadata for result in results]]
  474. self.session.rollback() # read-only transaction
  475. return GetResult(ids=ids, documents=documents, metadatas=metadatas)
  476. except Exception as e:
  477. self.session.rollback()
  478. log.exception(f"Error during get: {e}")
  479. return None
  480. def delete(
  481. self,
  482. collection_name: str,
  483. ids: Optional[List[str]] = None,
  484. filter: Optional[Dict[str, Any]] = None,
  485. ) -> None:
  486. try:
  487. if PGVECTOR_PGCRYPTO:
  488. wheres = [DocumentChunk.collection_name == collection_name]
  489. if ids:
  490. wheres.append(DocumentChunk.id.in_(ids))
  491. if filter:
  492. for key, value in filter.items():
  493. wheres.append(
  494. pgcrypto_decrypt(
  495. DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
  496. )[key].astext
  497. == str(value)
  498. )
  499. stmt = DocumentChunk.__table__.delete().where(*wheres)
  500. result = self.session.execute(stmt)
  501. deleted = result.rowcount
  502. else:
  503. query = self.session.query(DocumentChunk).filter(
  504. DocumentChunk.collection_name == collection_name
  505. )
  506. if ids:
  507. query = query.filter(DocumentChunk.id.in_(ids))
  508. if filter:
  509. for key, value in filter.items():
  510. query = query.filter(
  511. DocumentChunk.vmetadata[key].astext == str(value)
  512. )
  513. deleted = query.delete(synchronize_session=False)
  514. self.session.commit()
  515. log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
  516. except Exception as e:
  517. self.session.rollback()
  518. log.exception(f"Error during delete: {e}")
  519. raise
  520. def reset(self) -> None:
  521. try:
  522. deleted = self.session.query(DocumentChunk).delete()
  523. self.session.commit()
  524. log.info(
  525. f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
  526. )
  527. except Exception as e:
  528. self.session.rollback()
  529. log.exception(f"Error during reset: {e}")
  530. raise
    def close(self) -> None:
        # Intentional no-op: nothing is released here. The session is either
        # the shared internal db Session or a scoped_session whose pooled
        # connections SQLAlchemy manages.
        pass
  533. def has_collection(self, collection_name: str) -> bool:
  534. try:
  535. exists = (
  536. self.session.query(DocumentChunk)
  537. .filter(DocumentChunk.collection_name == collection_name)
  538. .first()
  539. is not None
  540. )
  541. self.session.rollback() # read-only transaction
  542. return exists
  543. except Exception as e:
  544. self.session.rollback()
  545. log.exception(f"Error checking collection existence: {e}")
  546. return False
  547. def delete_collection(self, collection_name: str) -> None:
  548. self.delete(collection_name)
  549. log.info(f"Collection '{collection_name}' deleted.")