# pgvector.py
  1. from typing import Optional, List, Dict, Any
  2. import logging
  3. import json
  4. from sqlalchemy import (
  5. func,
  6. literal,
  7. cast,
  8. column,
  9. create_engine,
  10. Column,
  11. Integer,
  12. MetaData,
  13. LargeBinary,
  14. select,
  15. text,
  16. Text,
  17. Table,
  18. values,
  19. )
  20. from sqlalchemy.sql import true
  21. from sqlalchemy.pool import NullPool
  22. from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
  23. from sqlalchemy.dialects.postgresql import JSONB, array
  24. from pgvector.sqlalchemy import Vector
  25. from sqlalchemy.ext.mutable import MutableDict
  26. from sqlalchemy.exc import NoSuchTableError
  27. from open_webui.retrieval.vector.main import (
  28. VectorDBBase,
  29. VectorItem,
  30. SearchResult,
  31. GetResult,
  32. )
  33. from open_webui.config import (
  34. PGVECTOR_DB_URL,
  35. PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
  36. PGVECTOR_PGCRYPTO,
  37. PGVECTOR_PGCRYPTO_KEY,
  38. )
  39. from open_webui.env import SRC_LOG_LEVELS
# Fixed dimensionality of the stored vectors; taken from configuration at
# import time and enforced against the existing table by check_vector_length().
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
# Declarative base for the ORM models defined in this module.
Base = declarative_base()
# Module logger, scoped to the RAG log level from the app configuration.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
  44. def pgcrypto_encrypt(val, key):
  45. return func.pgp_sym_encrypt(val, literal(key))
  46. def pgcrypto_decrypt(col, key, outtype="text"):
  47. return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype)
class DocumentChunk(Base):
    """ORM model for one embedded chunk of a document.

    The column types for ``text`` and ``vmetadata`` are chosen at import time:
    with PGVECTOR_PGCRYPTO enabled they are BYTEA columns holding pgcrypto
    ciphertext, otherwise plain TEXT and JSONB.
    """

    __tablename__ = "document_chunk"

    # Chunk identifier (primary key).
    id = Column(Text, primary_key=True)
    # Embedding with a fixed dimension of VECTOR_LENGTH.
    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
    # Logical collection the chunk belongs to.
    collection_name = Column(Text, nullable=False)
    if PGVECTOR_PGCRYPTO:
        # Encrypted at rest: raw pgp_sym_encrypt() output (BYTEA).
        text = Column(LargeBinary, nullable=True)
        vmetadata = Column(LargeBinary, nullable=True)
    else:
        text = Column(Text, nullable=True)
        # MutableDict so in-place metadata edits are tracked by the session.
        vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class PgvectorClient(VectorDBBase):
    """Vector database client backed by PostgreSQL with the pgvector extension."""

    def __init__(self) -> None:
        """Set up the session, required extensions, table, and indexes.

        Raises:
            ValueError: if PGVECTOR_PGCRYPTO is enabled without a key.
            Exception: re-raised after rollback if any setup step fails.
        """
        # if no pgvector uri, use the existing database connection
        if not PGVECTOR_DB_URL:
            from open_webui.internal.db import Session

            self.session = Session
        else:
            # Dedicated engine; NullPool opens/closes a connection per checkout.
            engine = create_engine(
                PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
            )
            SessionLocal = sessionmaker(
                autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
            )
            self.session = scoped_session(SessionLocal)

        try:
            # Ensure the pgvector extension is available
            self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))

            if PGVECTOR_PGCRYPTO:
                # Ensure the pgcrypto extension is available for encryption
                self.session.execute(text("CREATE EXTENSION IF NOT EXISTS pgcrypto;"))
                if not PGVECTOR_PGCRYPTO_KEY:
                    raise ValueError(
                        "PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled."
                    )

            # Check vector length consistency (fails fast on dimension mismatch).
            self.check_vector_length()

            # Create the tables if they do not exist
            # Base.metadata.create_all requires a bind (engine or connection)
            # Get the connection from the session
            connection = self.session.connection()
            Base.metadata.create_all(bind=connection)

            # Create an index on the vector column if it doesn't exist
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
                    "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
                )
            )
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
                    "ON document_chunk (collection_name);"
                )
            )

            self.session.commit()
            log.info("Initialization complete.")
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during initialization: {e}")
            raise
    def check_vector_length(self) -> None:
        """
        Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
        Raises an exception if there is a mismatch.
        """
        metadata = MetaData()
        try:
            # Attempt to reflect the 'document_chunk' table
            document_chunk_table = Table(
                "document_chunk", metadata, autoload_with=self.session.bind
            )
        except NoSuchTableError:
            # Table does not exist; no action needed
            return

        # Proceed to check the vector column
        if "vector" in document_chunk_table.columns:
            vector_column = document_chunk_table.columns["vector"]
            vector_type = vector_column.type
            if isinstance(vector_type, Vector):
                db_vector_length = vector_type.dim
                if db_vector_length != VECTOR_LENGTH:
                    # Changing dimensions would silently corrupt similarity
                    # search, so refuse to start instead.
                    raise Exception(
                        f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
                        "Cannot change vector size after initialization without migrating the data."
                    )
            else:
                raise Exception(
                    "The 'vector' column exists but is not of type 'Vector'."
                )
        else:
            raise Exception(
                "The 'vector' column does not exist in the 'document_chunk' table."
            )
  142. def adjust_vector_length(self, vector: List[float]) -> List[float]:
  143. # Adjust vector to have length VECTOR_LENGTH
  144. current_length = len(vector)
  145. if current_length < VECTOR_LENGTH:
  146. # Pad the vector with zeros
  147. vector += [0.0] * (VECTOR_LENGTH - current_length)
  148. elif current_length > VECTOR_LENGTH:
  149. # Truncate the vector to VECTOR_LENGTH
  150. vector = vector[:VECTOR_LENGTH]
  151. return vector
    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert *items* into *collection_name*; existing ids are left untouched.

        With PGVECTOR_PGCRYPTO enabled, text and metadata are encrypted server-
        side via pgp_sym_encrypt and the rows are written with raw SQL
        (ON CONFLICT DO NOTHING).  Otherwise plain ORM objects are bulk-saved.
        Commits on success, rolls back and re-raises on failure.
        """
        try:
            if PGVECTOR_PGCRYPTO:
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    # Use raw SQL for BYTEA/pgcrypto
                    self.session.execute(
                        text(
                            """
                            INSERT INTO document_chunk
                            (id, vector, collection_name, text, vmetadata)
                            VALUES (
                                :id, :vector, :collection_name,
                                pgp_sym_encrypt(:text, :key),
                                pgp_sym_encrypt(:metadata::text, :key)
                            )
                            ON CONFLICT (id) DO NOTHING
                            """
                        ),
                        {
                            "id": item["id"],
                            "vector": vector,
                            "collection_name": collection_name,
                            "text": item["text"],
                            # Metadata is serialized to JSON before encryption.
                            "metadata": json.dumps(item["metadata"]),
                            "key": PGVECTOR_PGCRYPTO_KEY,
                        },
                    )
                self.session.commit()
                log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'")
            else:
                new_items = []
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    new_chunk = DocumentChunk(
                        id=item["id"],
                        vector=vector,
                        collection_name=collection_name,
                        text=item["text"],
                        vmetadata=item["metadata"],
                    )
                    new_items.append(new_chunk)
                self.session.bulk_save_objects(new_items)
                self.session.commit()
                log.info(
                    f"Inserted {len(new_items)} items into collection '{collection_name}'."
                )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during insert: {e}")
            raise
    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert or update *items* in *collection_name*, keyed by item id.

        Encrypted mode uses a single INSERT ... ON CONFLICT DO UPDATE per item;
        plain mode does a query-then-update/insert round trip per item.
        Commits on success, rolls back and re-raises on failure.
        """
        try:
            if PGVECTOR_PGCRYPTO:
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    self.session.execute(
                        text(
                            """
                            INSERT INTO document_chunk
                            (id, vector, collection_name, text, vmetadata)
                            VALUES (
                                :id, :vector, :collection_name,
                                pgp_sym_encrypt(:text, :key),
                                pgp_sym_encrypt(:metadata::text, :key)
                            )
                            ON CONFLICT (id) DO UPDATE SET
                                vector = EXCLUDED.vector,
                                collection_name = EXCLUDED.collection_name,
                                text = EXCLUDED.text,
                                vmetadata = EXCLUDED.vmetadata
                            """
                        ),
                        {
                            "id": item["id"],
                            "vector": vector,
                            "collection_name": collection_name,
                            "text": item["text"],
                            "metadata": json.dumps(item["metadata"]),
                            "key": PGVECTOR_PGCRYPTO_KEY,
                        },
                    )
                self.session.commit()
                log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
            else:
                for item in items:
                    vector = self.adjust_vector_length(item["vector"])
                    existing = (
                        self.session.query(DocumentChunk)
                        .filter(DocumentChunk.id == item["id"])
                        .first()
                    )
                    if existing:
                        existing.vector = vector
                        existing.text = item["text"]
                        existing.vmetadata = item["metadata"]
                        existing.collection_name = (
                            collection_name  # Update collection_name if necessary
                        )
                    else:
                        new_chunk = DocumentChunk(
                            id=item["id"],
                            vector=vector,
                            collection_name=collection_name,
                            text=item["text"],
                            vmetadata=item["metadata"],
                        )
                        self.session.add(new_chunk)
                self.session.commit()
                log.info(
                    f"Upserted {len(items)} items into collection '{collection_name}'."
                )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during upsert: {e}")
            raise
    def search(
        self,
        collection_name: str,
        vectors: List[List[float]],
        limit: Optional[int] = None,
    ) -> Optional[SearchResult]:
        """Run a cosine-similarity search for each query vector in one statement.

        Builds a VALUES list of (qid, q_vector) rows joined to a LATERAL
        subquery, so the top-*limit* neighbors for every query vector are
        fetched in a single round trip.  Returns per-query result lists, or
        None on error / empty input.  Distances are rescaled from pgvector's
        cosine-distance range [0, 2] to a [0, 1] similarity score.
        """
        try:
            if not vectors:
                return None

            # Adjust query vectors to VECTOR_LENGTH
            vectors = [self.adjust_vector_length(vector) for vector in vectors]
            num_queries = len(vectors)

            def vector_expr(vector):
                # Cast the Python list to a pgvector value of the right dim.
                return cast(array(vector), Vector(VECTOR_LENGTH))

            # Create the values for query vectors
            qid_col = column("qid", Integer)
            q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
            query_vectors = (
                values(qid_col, q_vector_col)
                .data(
                    [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
                )
                .alias("query_vectors")
            )

            result_fields = [
                DocumentChunk.id,
            ]
            if PGVECTOR_PGCRYPTO:
                # Decrypt text/metadata in SQL so rows come back as plaintext.
                result_fields.append(
                    pgcrypto_decrypt(
                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
                    ).label("text")
                )
                result_fields.append(
                    pgcrypto_decrypt(
                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                    ).label("vmetadata")
                )
            else:
                result_fields.append(DocumentChunk.text)
                result_fields.append(DocumentChunk.vmetadata)
            result_fields.append(
                (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
                    "distance"
                )
            )

            # Build the lateral subquery for each query vector
            subq = (
                select(*result_fields)
                .where(DocumentChunk.collection_name == collection_name)
                .order_by(
                    (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
                )
            )
            if limit is not None:
                subq = subq.limit(limit)
            subq = subq.lateral("result")

            # Build the main query by joining query_vectors and the lateral subquery
            stmt = (
                select(
                    query_vectors.c.qid,
                    subq.c.id,
                    subq.c.text,
                    subq.c.vmetadata,
                    subq.c.distance,
                )
                .select_from(query_vectors)
                .join(subq, true())
                .order_by(query_vectors.c.qid, subq.c.distance)
            )

            result_proxy = self.session.execute(stmt)
            results = result_proxy.all()

            # One bucket per query vector, indexed by qid.
            ids = [[] for _ in range(num_queries)]
            distances = [[] for _ in range(num_queries)]
            documents = [[] for _ in range(num_queries)]
            metadatas = [[] for _ in range(num_queries)]

            if not results:
                return SearchResult(
                    ids=ids,
                    distances=distances,
                    documents=documents,
                    metadatas=metadatas,
                )

            for row in results:
                qid = int(row.qid)
                ids[qid].append(row.id)
                # normalize and re-orders pgvec distance from [2, 0] to [0, 1] score range
                # https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
                distances[qid].append((2.0 - row.distance) / 2.0)
                documents[qid].append(row.text)
                metadatas[qid].append(row.vmetadata)

            return SearchResult(
                ids=ids, distances=distances, documents=documents, metadatas=metadatas
            )
        except Exception as e:
            log.exception(f"Error during search: {e}")
            return None
    def query(
        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Fetch chunks in *collection_name* whose metadata matches *filter*.

        Each filter entry is compared as text against the JSONB metadata key.
        In encrypted mode the metadata is decrypted in SQL before filtering.
        Returns None when nothing matches or on error.
        """
        try:
            if PGVECTOR_PGCRYPTO:
                # Build where clause for vmetadata filter
                where_clauses = [DocumentChunk.collection_name == collection_name]
                for key, value in filter.items():
                    # decrypt then check key: JSON filter after decryption
                    where_clauses.append(
                        pgcrypto_decrypt(
                            DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                        )[key].astext
                        == str(value)
                    )
                stmt = select(
                    DocumentChunk.id,
                    pgcrypto_decrypt(
                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
                    ).label("text"),
                    pgcrypto_decrypt(
                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                    ).label("vmetadata"),
                ).where(*where_clauses)
                if limit is not None:
                    stmt = stmt.limit(limit)
                results = self.session.execute(stmt).all()
            else:
                query = self.session.query(DocumentChunk).filter(
                    DocumentChunk.collection_name == collection_name
                )
                for key, value in filter.items():
                    query = query.filter(
                        DocumentChunk.vmetadata[key].astext == str(value)
                    )
                if limit is not None:
                    query = query.limit(limit)
                results = query.all()

            if not results:
                return None

            # GetResult fields are lists-of-lists (single query bucket here).
            ids = [[result.id for result in results]]
            documents = [[result.text for result in results]]
            metadatas = [[result.vmetadata for result in results]]

            return GetResult(
                ids=ids,
                documents=documents,
                metadatas=metadatas,
            )
        except Exception as e:
            log.exception(f"Error during query: {e}")
            return None
    def get(
        self, collection_name: str, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Fetch up to *limit* chunks from *collection_name* without filtering.

        Returns a GetResult with single-bucket lists-of-lists, or None on
        error.  NOTE(review): the plain branch returns None for an empty
        collection while the pgcrypto branch returns empty lists — presumably
        callers accept both; confirm before unifying.
        """
        try:
            if PGVECTOR_PGCRYPTO:
                stmt = select(
                    DocumentChunk.id,
                    pgcrypto_decrypt(
                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
                    ).label("text"),
                    pgcrypto_decrypt(
                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                    ).label("vmetadata"),
                ).where(DocumentChunk.collection_name == collection_name)
                if limit is not None:
                    stmt = stmt.limit(limit)
                results = self.session.execute(stmt).all()
                ids = [[row.id for row in results]]
                documents = [[row.text for row in results]]
                metadatas = [[row.vmetadata for row in results]]
            else:
                query = self.session.query(DocumentChunk).filter(
                    DocumentChunk.collection_name == collection_name
                )
                if limit is not None:
                    query = query.limit(limit)
                results = query.all()

                if not results:
                    return None

                ids = [[result.id for result in results]]
                documents = [[result.text for result in results]]
                metadatas = [[result.vmetadata for result in results]]
            return GetResult(ids=ids, documents=documents, metadatas=metadatas)
        except Exception as e:
            log.exception(f"Error during get: {e}")
            return None
    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Delete chunks from *collection_name*, optionally narrowed by ids
        and/or a metadata filter.

        With no ids/filter the whole collection is removed.  Commits on
        success, rolls back and re-raises on failure.
        """
        try:
            if PGVECTOR_PGCRYPTO:
                wheres = [DocumentChunk.collection_name == collection_name]
                if ids:
                    wheres.append(DocumentChunk.id.in_(ids))
                if filter:
                    for key, value in filter.items():
                        # Metadata must be decrypted in SQL before filtering.
                        wheres.append(
                            pgcrypto_decrypt(
                                DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                            )[key].astext
                            == str(value)
                        )
                stmt = DocumentChunk.__table__.delete().where(*wheres)
                result = self.session.execute(stmt)
                deleted = result.rowcount
            else:
                query = self.session.query(DocumentChunk).filter(
                    DocumentChunk.collection_name == collection_name
                )
                if ids:
                    query = query.filter(DocumentChunk.id.in_(ids))
                if filter:
                    for key, value in filter.items():
                        query = query.filter(
                            DocumentChunk.vmetadata[key].astext == str(value)
                        )
                deleted = query.delete(synchronize_session=False)
            self.session.commit()
            log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during delete: {e}")
            raise
  492. def reset(self) -> None:
  493. try:
  494. deleted = self.session.query(DocumentChunk).delete()
  495. self.session.commit()
  496. log.info(
  497. f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
  498. )
  499. except Exception as e:
  500. self.session.rollback()
  501. log.exception(f"Error during reset: {e}")
  502. raise
    def close(self) -> None:
        """No-op: nothing is explicitly released here; the session is either
        the shared app session or a scoped session managed elsewhere."""
        pass
  505. def has_collection(self, collection_name: str) -> bool:
  506. try:
  507. exists = (
  508. self.session.query(DocumentChunk)
  509. .filter(DocumentChunk.collection_name == collection_name)
  510. .first()
  511. is not None
  512. )
  513. return exists
  514. except Exception as e:
  515. log.exception(f"Error checking collection existence: {e}")
  516. return False
  517. def delete_collection(self, collection_name: str) -> None:
  518. self.delete(collection_name)
  519. log.info(f"Collection '{collection_name}' deleted.")