
feat: experimental pgvector pgcrypto support

Timothy Jaeryang Baek, 4 months ago
Parent
Current commit: 7f488b3754
2 files changed, 248 insertions(+), 81 deletions(-)
  1. backend/open_webui/config.py (+7, -0)
  2. backend/open_webui/retrieval/vector/dbs/pgvector.py (+241, -81)

+ 7 - 0
backend/open_webui/config.py

@@ -1825,6 +1825,13 @@ PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int(
     os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
 )
 
+PGVECTOR_PGCRYPTO = os.getenv("PGVECTOR_PGCRYPTO", "false").lower() == "true"
+PGVECTOR_PGCRYPTO_KEY = os.getenv("PGVECTOR_PGCRYPTO_KEY", None)
+if PGVECTOR_PGCRYPTO and not PGVECTOR_PGCRYPTO_KEY:
+    raise ValueError(
+        "PGVECTOR_PGCRYPTO is enabled but PGVECTOR_PGCRYPTO_KEY is not set. Please provide a valid key."
+    )
+
 # Pinecone
 PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
 PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
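
A standalone sketch (not part of the commit) of the gating behaviour these two settings introduce, replicating the check above with placeholder values instead of importing open_webui.config:

# Standalone sketch: PGVECTOR_PGCRYPTO must be the string "true"
# (case-insensitive), and a key must accompany it. Values are placeholders.
import os

os.environ["PGVECTOR_PGCRYPTO"] = "true"
os.environ["PGVECTOR_PGCRYPTO_KEY"] = "change-me"  # placeholder, not a real secret

pgcrypto_enabled = os.getenv("PGVECTOR_PGCRYPTO", "false").lower() == "true"
pgcrypto_key = os.getenv("PGVECTOR_PGCRYPTO_KEY", None)

if pgcrypto_enabled and not pgcrypto_key:
    # Mirrors the config module: a misconfigured deployment fails at import
    # time rather than at the first encryption attempt.
    raise ValueError(
        "PGVECTOR_PGCRYPTO is enabled but PGVECTOR_PGCRYPTO_KEY is not set."
    )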

+ 241 - 81
backend/open_webui/retrieval/vector/dbs/pgvector.py

@@ -1,12 +1,16 @@
 from typing import Optional, List, Dict, Any
 import logging
+import json
 from sqlalchemy import (
+    func,
+    literal,
     cast,
     column,
     create_engine,
     Column,
     Integer,
     MetaData,
+    LargeBinary,
     select,
     text,
     Text,
@@ -28,7 +32,12 @@ from open_webui.retrieval.vector.main import (
     SearchResult,
     GetResult,
 )
-from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
+from open_webui.config import (
+    PGVECTOR_DB_URL,
+    PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
+    PGVECTOR_PGCRYPTO,
+    PGVECTOR_PGCRYPTO_KEY,
+)
 
 from open_webui.env import SRC_LOG_LEVELS
 
@@ -39,14 +48,27 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
+def pgcrypto_encrypt(val, key):
+    return func.pgp_sym_encrypt(val, literal(key))
+
+
+def pgcrypto_decrypt(col, key, outtype="text"):
+    return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype)
+
+
 class DocumentChunk(Base):
     __tablename__ = "document_chunk"
 
     id = Column(Text, primary_key=True)
     vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
     collection_name = Column(Text, nullable=False)
-    text = Column(Text, nullable=True)
-    vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
+
+    if PGVECTOR_PGCRYPTO:
+        text = Column(LargeBinary, nullable=True)
+        vmetadata = Column(LargeBinary, nullable=True)
+    else:
+        text = Column(Text, nullable=True)
+        vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
 
 
 class PgvectorClient(VectorDBBase):
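
The pgcrypto_encrypt / pgcrypto_decrypt helpers above wrap PostgreSQL's pgp_sym_encrypt and pgp_sym_decrypt in SQLAlchemy expressions. A standalone sketch (not part of the commit) of what such expressions render to against the PostgreSQL dialect; the column name and key are stand-ins, and the decrypt side uses SQLAlchemy's cast() construct for illustration:

# Standalone sketch: compile pgcrypto expressions against the PostgreSQL
# dialect to inspect the SQL they produce. Column and key are stand-ins.
from sqlalchemy import Text, cast, column, func, literal
from sqlalchemy.dialects import postgresql

ciphertext = column("text")   # stand-in for the BYTEA 'text' column
key = literal("change-me")    # placeholder symmetric key

encrypt_expr = func.pgp_sym_encrypt(literal("hello world"), key)
decrypt_expr = cast(func.pgp_sym_decrypt(ciphertext, key), Text)

print(encrypt_expr.compile(dialect=postgresql.dialect()))
# roughly: pgp_sym_encrypt(<plaintext param>, <key param>)
print(decrypt_expr.compile(dialect=postgresql.dialect()))
# roughly: CAST(pgp_sym_decrypt(text, <key param>) AS TEXT)
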
@@ -147,22 +169,52 @@ class PgvectorClient(VectorDBBase):
 
     def insert(self, collection_name: str, items: List[VectorItem]) -> None:
         try:
-            new_items = []
-            for item in items:
-                vector = self.adjust_vector_length(item["vector"])
-                new_chunk = DocumentChunk(
-                    id=item["id"],
-                    vector=vector,
-                    collection_name=collection_name,
-                    text=item["text"],
-                    vmetadata=item["metadata"],
+            if PGVECTOR_PGCRYPTO:
+                for item in items:
+                    vector = self.adjust_vector_length(item["vector"])
+                    # Use raw SQL for BYTEA/pgcrypto
+                    self.session.execute(
+                        text(
+                            """
+                            INSERT INTO document_chunk
+                            (id, vector, collection_name, text, vmetadata)
+                            VALUES (
+                                :id, :vector, :collection_name,
+                                pgp_sym_encrypt(:text, :key),
+                                pgp_sym_encrypt(:metadata::text, :key)
+                            )
+                            ON CONFLICT (id) DO NOTHING
+                        """
+                        ),
+                        {
+                            "id": item["id"],
+                            "vector": vector,
+                            "collection_name": collection_name,
+                            "text": item["text"],
+                            "metadata": json.dumps(item["metadata"]),
+                            "key": PGVECTOR_PGCRYPTO_KEY,
+                        },
+                    )
+                self.session.commit()
+                log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'")
+
+            else:
+                new_items = []
+                for item in items:
+                    vector = self.adjust_vector_length(item["vector"])
+                    new_chunk = DocumentChunk(
+                        id=item["id"],
+                        vector=vector,
+                        collection_name=collection_name,
+                        text=item["text"],
+                        vmetadata=item["metadata"],
+                    )
+                    new_items.append(new_chunk)
+                self.session.bulk_save_objects(new_items)
+                self.session.commit()
+                log.info(
+                    f"Inserted {len(new_items)} items into collection '{collection_name}'."
                 )
-                new_items.append(new_chunk)
-            self.session.bulk_save_objects(new_items)
-            self.session.commit()
-            log.info(
-                f"Inserted {len(new_items)} items into collection '{collection_name}'."
-            )
         except Exception as e:
             self.session.rollback()
             log.exception(f"Error during insert: {e}")
@@ -170,33 +222,65 @@ class PgvectorClient(VectorDBBase):
 
     def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
         try:
-            for item in items:
-                vector = self.adjust_vector_length(item["vector"])
-                existing = (
-                    self.session.query(DocumentChunk)
-                    .filter(DocumentChunk.id == item["id"])
-                    .first()
-                )
-                if existing:
-                    existing.vector = vector
-                    existing.text = item["text"]
-                    existing.vmetadata = item["metadata"]
-                    existing.collection_name = (
-                        collection_name  # Update collection_name if necessary
+            if PGVECTOR_PGCRYPTO:
+                for item in items:
+                    vector = self.adjust_vector_length(item["vector"])
+                    self.session.execute(
+                        text(
+                            """
+                            INSERT INTO document_chunk
+                            (id, vector, collection_name, text, vmetadata)
+                            VALUES (
+                                :id, :vector, :collection_name,
+                                pgp_sym_encrypt(:text, :key),
+                                pgp_sym_encrypt(:metadata::text, :key)
+                            )
+                            ON CONFLICT (id) DO UPDATE SET
+                              vector = EXCLUDED.vector,
+                              collection_name = EXCLUDED.collection_name,
+                              text = EXCLUDED.text,
+                              vmetadata = EXCLUDED.vmetadata
+                        """
+                        ),
+                        {
+                            "id": item["id"],
+                            "vector": vector,
+                            "collection_name": collection_name,
+                            "text": item["text"],
+                            "metadata": json.dumps(item["metadata"]),
+                            "key": PGVECTOR_PGCRYPTO_KEY,
+                        },
                     )
-                else:
-                    new_chunk = DocumentChunk(
-                        id=item["id"],
-                        vector=vector,
-                        collection_name=collection_name,
-                        text=item["text"],
-                        vmetadata=item["metadata"],
+                self.session.commit()
+                log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
+            else:
+                for item in items:
+                    vector = self.adjust_vector_length(item["vector"])
+                    existing = (
+                        self.session.query(DocumentChunk)
+                        .filter(DocumentChunk.id == item["id"])
+                        .first()
                     )
-                    self.session.add(new_chunk)
-            self.session.commit()
-            log.info(
-                f"Upserted {len(items)} items into collection '{collection_name}'."
-            )
+                    if existing:
+                        existing.vector = vector
+                        existing.text = item["text"]
+                        existing.vmetadata = item["metadata"]
+                        existing.collection_name = (
+                            collection_name  # Update collection_name if necessary
+                        )
+                    else:
+                        new_chunk = DocumentChunk(
+                            id=item["id"],
+                            vector=vector,
+                            collection_name=collection_name,
+                            text=item["text"],
+                            vmetadata=item["metadata"],
+                        )
+                        self.session.add(new_chunk)
+                self.session.commit()
+                log.info(
+                    f"Upserted {len(items)} items into collection '{collection_name}'."
+                )
         except Exception as e:
             self.session.rollback()
             log.exception(f"Error during upsert: {e}")
@@ -230,16 +314,32 @@ class PgvectorClient(VectorDBBase):
                 .alias("query_vectors")
             )
 
+            result_fields = [
+                DocumentChunk.id,
+            ]
+            if PGVECTOR_PGCRYPTO:
+                result_fields.append(
+                    pgcrypto_decrypt(
+                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
+                    ).label("text")
+                )
+                result_fields.append(
+                    pgcrypto_decrypt(
+                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
+                    ).label("vmetadata")
+                )
+            else:
+                result_fields.append(DocumentChunk.text)
+                result_fields.append(DocumentChunk.vmetadata)
+            result_fields.append(
+                (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
+                    "distance"
+                )
+            )
+
             # Build the lateral subquery for each query vector
             subq = (
-                select(
-                    DocumentChunk.id,
-                    DocumentChunk.text,
-                    DocumentChunk.vmetadata,
-                    (
-                        DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
-                    ).label("distance"),
-                )
+                select(*result_fields)
                 .where(DocumentChunk.collection_name == collection_name)
                 .order_by(
                     (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
@@ -299,17 +399,43 @@ class PgvectorClient(VectorDBBase):
         self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
     ) -> Optional[GetResult]:
         try:
-            query = self.session.query(DocumentChunk).filter(
-                DocumentChunk.collection_name == collection_name
-            )
+            if PGVECTOR_PGCRYPTO:
+                # Build where clause for vmetadata filter
+                where_clauses = [DocumentChunk.collection_name == collection_name]
+                for key, value in filter.items():
+                    # decrypt then check key: JSON filter after decryption
+                    where_clauses.append(
+                        pgcrypto_decrypt(
+                            DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
+                        )[key].astext
+                        == str(value)
+                    )
+                stmt = select(
+                    DocumentChunk.id,
+                    pgcrypto_decrypt(
+                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
+                    ).label("text"),
+                    pgcrypto_decrypt(
+                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
+                    ).label("vmetadata"),
+                ).where(*where_clauses)
+                if limit is not None:
+                    stmt = stmt.limit(limit)
+                results = self.session.execute(stmt).all()
+            else:
+                query = self.session.query(DocumentChunk).filter(
+                    DocumentChunk.collection_name == collection_name
+                )
 
-            for key, value in filter.items():
-                query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
+                for key, value in filter.items():
+                    query = query.filter(
+                        DocumentChunk.vmetadata[key].astext == str(value)
+                    )
 
-            if limit is not None:
-                query = query.limit(limit)
+                if limit is not None:
+                    query = query.limit(limit)
 
-            results = query.all()
+                results = query.all()
 
             if not results:
                 return None
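
In the encrypted path, metadata filtering has to decrypt vmetadata before the JSON operators can be applied, which is what the where clause above builds. A standalone sketch of the equivalent raw-SQL filter, with a made-up collection name, metadata key, and value:

# Standalone sketch: decrypt vmetadata, cast to jsonb, then compare one key.
from sqlalchemy import create_engine, text
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_PGCRYPTO_KEY

engine = create_engine(PGVECTOR_DB_URL)
with engine.connect() as conn:
    rows = conn.execute(
        text(
            "SELECT id, pgp_sym_decrypt(text, :key) AS text "
            "  FROM document_chunk "
            " WHERE collection_name = :collection "
            "   AND pgp_sym_decrypt(vmetadata, :key)::jsonb ->> :mkey = :mval"
        ),
        {
            "key": PGVECTOR_PGCRYPTO_KEY,
            "collection": "demo_collection",
            "mkey": "source",
            "mval": "example.txt",
        },
    ).all()

Note that every candidate row is decrypted server-side before the comparison, so these filters cannot use an index on vmetadata; that is part of the cost of keeping the column encrypted at rest.
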
@@ -331,20 +457,38 @@ class PgvectorClient(VectorDBBase):
         self, collection_name: str, limit: Optional[int] = None
     ) -> Optional[GetResult]:
         try:
-            query = self.session.query(DocumentChunk).filter(
-                DocumentChunk.collection_name == collection_name
-            )
-            if limit is not None:
-                query = query.limit(limit)
+            if PGVECTOR_PGCRYPTO:
+                stmt = select(
+                    DocumentChunk.id,
+                    pgcrypto_decrypt(
+                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
+                    ).label("text"),
+                    pgcrypto_decrypt(
+                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
+                    ).label("vmetadata"),
+                ).where(DocumentChunk.collection_name == collection_name)
+                if limit is not None:
+                    stmt = stmt.limit(limit)
+                results = self.session.execute(stmt).all()
+                ids = [[row.id for row in results]]
+                documents = [[row.text for row in results]]
+                metadatas = [[row.vmetadata for row in results]]
+            else:
 
-            results = query.all()
+                query = self.session.query(DocumentChunk).filter(
+                    DocumentChunk.collection_name == collection_name
+                )
+                if limit is not None:
+                    query = query.limit(limit)
 
-            if not results:
-                return None
+                results = query.all()
 
-            ids = [[result.id for result in results]]
-            documents = [[result.text for result in results]]
-            metadatas = [[result.vmetadata for result in results]]
+                if not results:
+                    return None
+
+                ids = [[result.id for result in results]]
+                documents = [[result.text for result in results]]
+                metadatas = [[result.vmetadata for result in results]]
 
             return GetResult(ids=ids, documents=documents, metadatas=metadatas)
         except Exception as e:
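
A hypothetical sketch of reading a collection back through get(); GetResult nests ids, documents, and metadatas one level deep, so the single inner list is unpacked. The collection name is made up, and PgvectorClient is assumed to construct without arguments.

# Hypothetical usage sketch: fetch up to 10 chunks and unpack the GetResult.
from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient

client = PgvectorClient()  # assumed no-argument construction
result = client.get("demo_collection", limit=10)
if result is not None:
    for chunk_id, doc, meta in zip(
        result.ids[0], result.documents[0], result.metadatas[0]
    ):
        print(chunk_id, meta, doc[:80])
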
@@ -358,17 +502,33 @@ class PgvectorClient(VectorDBBase):
         filter: Optional[Dict[str, Any]] = None,
     ) -> None:
         try:
-            query = self.session.query(DocumentChunk).filter(
-                DocumentChunk.collection_name == collection_name
-            )
-            if ids:
-                query = query.filter(DocumentChunk.id.in_(ids))
-            if filter:
-                for key, value in filter.items():
-                    query = query.filter(
-                        DocumentChunk.vmetadata[key].astext == str(value)
-                    )
-            deleted = query.delete(synchronize_session=False)
+            if PGVECTOR_PGCRYPTO:
+                wheres = [DocumentChunk.collection_name == collection_name]
+                if ids:
+                    wheres.append(DocumentChunk.id.in_(ids))
+                if filter:
+                    for key, value in filter.items():
+                        wheres.append(
+                            pgcrypto_decrypt(
+                                DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
+                            )[key].astext
+                            == str(value)
+                        )
+                stmt = DocumentChunk.__table__.delete().where(*wheres)
+                result = self.session.execute(stmt)
+                deleted = result.rowcount
+            else:
+                query = self.session.query(DocumentChunk).filter(
+                    DocumentChunk.collection_name == collection_name
+                )
+                if ids:
+                    query = query.filter(DocumentChunk.id.in_(ids))
+                if filter:
+                    for key, value in filter.items():
+                        query = query.filter(
+                            DocumentChunk.vmetadata[key].astext == str(value)
+                        )
+                deleted = query.delete(synchronize_session=False)
             self.session.commit()
             log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
         except Exception as e:
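
All of the pgp_sym_* calls in this diff come from PostgreSQL's pgcrypto extension, so it must be available in the target database before PGVECTOR_PGCRYPTO is enabled. A minimal sketch for deployments that do not already provide it (requires a role allowed to create extensions):

# Minimal sketch: ensure the pgcrypto extension exists.
from sqlalchemy import create_engine, text
from open_webui.config import PGVECTOR_DB_URL

engine = create_engine(PGVECTOR_DB_URL)
with engine.begin() as conn:
    conn.execute(text("CREATE EXTENSION IF NOT EXISTS pgcrypto"))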