|
@@ -1,12 +1,16 @@
|
|
|
from typing import Optional, List, Dict, Any
|
|
|
import logging
|
|
|
+import json
|
|
|
from sqlalchemy import (
|
|
|
+ func,
|
|
|
+ literal,
|
|
|
cast,
|
|
|
column,
|
|
|
create_engine,
|
|
|
Column,
|
|
|
Integer,
|
|
|
MetaData,
|
|
|
+ LargeBinary,
|
|
|
select,
|
|
|
text,
|
|
|
Text,
|
|
@@ -28,7 +32,12 @@ from open_webui.retrieval.vector.main import (
|
|
|
SearchResult,
|
|
|
GetResult,
|
|
|
)
|
|
|
-from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
|
|
+from open_webui.config import (
|
|
|
+ PGVECTOR_DB_URL,
|
|
|
+ PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
|
|
|
+ PGVECTOR_PGCRYPTO,
|
|
|
+ PGVECTOR_PGCRYPTO_KEY,
|
|
|
+)
|
|
|
|
|
|
from open_webui.env import SRC_LOG_LEVELS
|
|
|
|
|
@@ -39,14 +48,27 @@ log = logging.getLogger(__name__)
|
|
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
|
|
|
|
|
|
|
|
def pgcrypto_encrypt(val, key):
    """Build a SQL expression that encrypts ``val`` via ``pgp_sym_encrypt``.

    Args:
        val: Value or SQL expression to be encrypted.
        key: Symmetric passphrase; it is wrapped in ``literal`` before being
            handed to the SQL function.

    Returns:
        The ``pgp_sym_encrypt(val, key)`` function expression.
    """
    passphrase = literal(key)
    return func.pgp_sym_encrypt(val, passphrase)
|
|
|
+
|
|
|
+
|
|
|
def pgcrypto_decrypt(col, key, outtype="text"):
    """Build a SQL expression that decrypts ``col`` and casts the result.

    Args:
        col: Column or expression holding the encrypted value.
        key: Symmetric passphrase; it is wrapped in ``literal`` before being
            handed to the SQL function.
        outtype: Target type of the cast applied to the decrypted value.

    Returns:
        ``pgp_sym_decrypt(col, key)`` wrapped in a cast to ``outtype``.
    """
    # NOTE(review): the default is the plain string "text", while the callers
    # visible here pass type objects (Text / JSONB) -- confirm that a string
    # is an acceptable cast target before relying on the default.
    decrypted = func.pgp_sym_decrypt(col, literal(key))
    return func.cast(decrypted, outtype)
|
|
|
+
|
|
|
+
|
|
|
class DocumentChunk(Base):
    """ORM mapping for the ``document_chunk`` table.

    The ``text`` and ``vmetadata`` columns change type with the
    ``PGVECTOR_PGCRYPTO`` flag: binary columns when it is truthy, otherwise
    plain text and a mutable JSONB dictionary respectively.
    """

    __tablename__ = "document_chunk"

    id = Column(Text, primary_key=True)
    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
    collection_name = Column(Text, nullable=False)

    # Column types are picked once, at class-definition time, from the flag.
    text = Column(LargeBinary if PGVECTOR_PGCRYPTO else Text, nullable=True)
    vmetadata = Column(
        LargeBinary if PGVECTOR_PGCRYPTO else MutableDict.as_mutable(JSONB),
        nullable=True,
    )
|
|
|
|
|
|
|
|
|
class PgvectorClient(VectorDBBase):
|
|
@@ -147,22 +169,52 @@ class PgvectorClient(VectorDBBase):
|
|
|
|
|
|
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
|
|
try:
|
|
|
- new_items = []
|
|
|
- for item in items:
|
|
|
- vector = self.adjust_vector_length(item["vector"])
|
|
|
- new_chunk = DocumentChunk(
|
|
|
- id=item["id"],
|
|
|
- vector=vector,
|
|
|
- collection_name=collection_name,
|
|
|
- text=item["text"],
|
|
|
- vmetadata=item["metadata"],
|
|
|
+ if PGVECTOR_PGCRYPTO:
|
|
|
+ for item in items:
|
|
|
+ vector = self.adjust_vector_length(item["vector"])
|
|
|
+ # Use raw SQL for BYTEA/pgcrypto
|
|
|
+ self.session.execute(
|
|
|
+ text(
|
|
|
+ """
|
|
|
+ INSERT INTO document_chunk
|
|
|
+ (id, vector, collection_name, text, vmetadata)
|
|
|
+ VALUES (
|
|
|
+ :id, :vector, :collection_name,
|
|
|
+ pgp_sym_encrypt(:text, :key),
|
|
|
+ pgp_sym_encrypt(:metadata::text, :key)
|
|
|
+ )
|
|
|
+ ON CONFLICT (id) DO NOTHING
|
|
|
+ """
|
|
|
+ ),
|
|
|
+ {
|
|
|
+ "id": item["id"],
|
|
|
+ "vector": vector,
|
|
|
+ "collection_name": collection_name,
|
|
|
+ "text": item["text"],
|
|
|
+ "metadata": json.dumps(item["metadata"]),
|
|
|
+ "key": PGVECTOR_PGCRYPTO_KEY,
|
|
|
+ },
|
|
|
+ )
|
|
|
+ self.session.commit()
|
|
|
+ log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'")
|
|
|
+
|
|
|
+ else:
|
|
|
+ new_items = []
|
|
|
+ for item in items:
|
|
|
+ vector = self.adjust_vector_length(item["vector"])
|
|
|
+ new_chunk = DocumentChunk(
|
|
|
+ id=item["id"],
|
|
|
+ vector=vector,
|
|
|
+ collection_name=collection_name,
|
|
|
+ text=item["text"],
|
|
|
+ vmetadata=item["metadata"],
|
|
|
+ )
|
|
|
+ new_items.append(new_chunk)
|
|
|
+ self.session.bulk_save_objects(new_items)
|
|
|
+ self.session.commit()
|
|
|
+ log.info(
|
|
|
+ f"Inserted {len(new_items)} items into collection '{collection_name}'."
|
|
|
)
|
|
|
- new_items.append(new_chunk)
|
|
|
- self.session.bulk_save_objects(new_items)
|
|
|
- self.session.commit()
|
|
|
- log.info(
|
|
|
- f"Inserted {len(new_items)} items into collection '{collection_name}'."
|
|
|
- )
|
|
|
except Exception as e:
|
|
|
self.session.rollback()
|
|
|
log.exception(f"Error during insert: {e}")
|
|
@@ -170,33 +222,65 @@ class PgvectorClient(VectorDBBase):
|
|
|
|
|
|
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
|
|
try:
|
|
|
- for item in items:
|
|
|
- vector = self.adjust_vector_length(item["vector"])
|
|
|
- existing = (
|
|
|
- self.session.query(DocumentChunk)
|
|
|
- .filter(DocumentChunk.id == item["id"])
|
|
|
- .first()
|
|
|
- )
|
|
|
- if existing:
|
|
|
- existing.vector = vector
|
|
|
- existing.text = item["text"]
|
|
|
- existing.vmetadata = item["metadata"]
|
|
|
- existing.collection_name = (
|
|
|
- collection_name # Update collection_name if necessary
|
|
|
+ if PGVECTOR_PGCRYPTO:
|
|
|
+ for item in items:
|
|
|
+ vector = self.adjust_vector_length(item["vector"])
|
|
|
+ self.session.execute(
|
|
|
+ text(
|
|
|
+ """
|
|
|
+ INSERT INTO document_chunk
|
|
|
+ (id, vector, collection_name, text, vmetadata)
|
|
|
+ VALUES (
|
|
|
+ :id, :vector, :collection_name,
|
|
|
+ pgp_sym_encrypt(:text, :key),
|
|
|
+ pgp_sym_encrypt(:metadata::text, :key)
|
|
|
+ )
|
|
|
+ ON CONFLICT (id) DO UPDATE SET
|
|
|
+ vector = EXCLUDED.vector,
|
|
|
+ collection_name = EXCLUDED.collection_name,
|
|
|
+ text = EXCLUDED.text,
|
|
|
+ vmetadata = EXCLUDED.vmetadata
|
|
|
+ """
|
|
|
+ ),
|
|
|
+ {
|
|
|
+ "id": item["id"],
|
|
|
+ "vector": vector,
|
|
|
+ "collection_name": collection_name,
|
|
|
+ "text": item["text"],
|
|
|
+ "metadata": json.dumps(item["metadata"]),
|
|
|
+ "key": PGVECTOR_PGCRYPTO_KEY,
|
|
|
+ },
|
|
|
)
|
|
|
- else:
|
|
|
- new_chunk = DocumentChunk(
|
|
|
- id=item["id"],
|
|
|
- vector=vector,
|
|
|
- collection_name=collection_name,
|
|
|
- text=item["text"],
|
|
|
- vmetadata=item["metadata"],
|
|
|
+ self.session.commit()
|
|
|
+ log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
|
|
|
+ else:
|
|
|
+ for item in items:
|
|
|
+ vector = self.adjust_vector_length(item["vector"])
|
|
|
+ existing = (
|
|
|
+ self.session.query(DocumentChunk)
|
|
|
+ .filter(DocumentChunk.id == item["id"])
|
|
|
+ .first()
|
|
|
)
|
|
|
- self.session.add(new_chunk)
|
|
|
- self.session.commit()
|
|
|
- log.info(
|
|
|
- f"Upserted {len(items)} items into collection '{collection_name}'."
|
|
|
- )
|
|
|
+ if existing:
|
|
|
+ existing.vector = vector
|
|
|
+ existing.text = item["text"]
|
|
|
+ existing.vmetadata = item["metadata"]
|
|
|
+ existing.collection_name = (
|
|
|
+ collection_name # Update collection_name if necessary
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ new_chunk = DocumentChunk(
|
|
|
+ id=item["id"],
|
|
|
+ vector=vector,
|
|
|
+ collection_name=collection_name,
|
|
|
+ text=item["text"],
|
|
|
+ vmetadata=item["metadata"],
|
|
|
+ )
|
|
|
+ self.session.add(new_chunk)
|
|
|
+ self.session.commit()
|
|
|
+ log.info(
|
|
|
+ f"Upserted {len(items)} items into collection '{collection_name}'."
|
|
|
+ )
|
|
|
except Exception as e:
|
|
|
self.session.rollback()
|
|
|
log.exception(f"Error during upsert: {e}")
|
|
@@ -230,16 +314,32 @@ class PgvectorClient(VectorDBBase):
|
|
|
.alias("query_vectors")
|
|
|
)
|
|
|
|
|
|
+ result_fields = [
|
|
|
+ DocumentChunk.id,
|
|
|
+ ]
|
|
|
+ if PGVECTOR_PGCRYPTO:
|
|
|
+ result_fields.append(
|
|
|
+ pgcrypto_decrypt(
|
|
|
+ DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
|
|
|
+ ).label("text")
|
|
|
+ )
|
|
|
+ result_fields.append(
|
|
|
+ pgcrypto_decrypt(
|
|
|
+ DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
|
|
+ ).label("vmetadata")
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ result_fields.append(DocumentChunk.text)
|
|
|
+ result_fields.append(DocumentChunk.vmetadata)
|
|
|
+ result_fields.append(
|
|
|
+ (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
|
|
|
+ "distance"
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
# Build the lateral subquery for each query vector
|
|
|
subq = (
|
|
|
- select(
|
|
|
- DocumentChunk.id,
|
|
|
- DocumentChunk.text,
|
|
|
- DocumentChunk.vmetadata,
|
|
|
- (
|
|
|
- DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
|
|
|
- ).label("distance"),
|
|
|
- )
|
|
|
+ select(*result_fields)
|
|
|
.where(DocumentChunk.collection_name == collection_name)
|
|
|
.order_by(
|
|
|
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
|
|
@@ -299,17 +399,43 @@ class PgvectorClient(VectorDBBase):
|
|
|
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
|
|
) -> Optional[GetResult]:
|
|
|
try:
|
|
|
- query = self.session.query(DocumentChunk).filter(
|
|
|
- DocumentChunk.collection_name == collection_name
|
|
|
- )
|
|
|
+ if PGVECTOR_PGCRYPTO:
|
|
|
+ # Build where clause for vmetadata filter
|
|
|
+ where_clauses = [DocumentChunk.collection_name == collection_name]
|
|
|
+ for key, value in filter.items():
|
|
|
+ # decrypt then check key: JSON filter after decryption
|
|
|
+ where_clauses.append(
|
|
|
+ pgcrypto_decrypt(
|
|
|
+ DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
|
|
+ )[key].astext
|
|
|
+ == str(value)
|
|
|
+ )
|
|
|
+ stmt = select(
|
|
|
+ DocumentChunk.id,
|
|
|
+ pgcrypto_decrypt(
|
|
|
+ DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
|
|
|
+ ).label("text"),
|
|
|
+ pgcrypto_decrypt(
|
|
|
+ DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
|
|
+ ).label("vmetadata"),
|
|
|
+ ).where(*where_clauses)
|
|
|
+ if limit is not None:
|
|
|
+ stmt = stmt.limit(limit)
|
|
|
+ results = self.session.execute(stmt).all()
|
|
|
+ else:
|
|
|
+ query = self.session.query(DocumentChunk).filter(
|
|
|
+ DocumentChunk.collection_name == collection_name
|
|
|
+ )
|
|
|
|
|
|
- for key, value in filter.items():
|
|
|
- query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
|
|
|
+ for key, value in filter.items():
|
|
|
+ query = query.filter(
|
|
|
+ DocumentChunk.vmetadata[key].astext == str(value)
|
|
|
+ )
|
|
|
|
|
|
- if limit is not None:
|
|
|
- query = query.limit(limit)
|
|
|
+ if limit is not None:
|
|
|
+ query = query.limit(limit)
|
|
|
|
|
|
- results = query.all()
|
|
|
+ results = query.all()
|
|
|
|
|
|
if not results:
|
|
|
return None
|
|
@@ -331,20 +457,38 @@ class PgvectorClient(VectorDBBase):
|
|
|
self, collection_name: str, limit: Optional[int] = None
|
|
|
) -> Optional[GetResult]:
|
|
|
try:
|
|
|
- query = self.session.query(DocumentChunk).filter(
|
|
|
- DocumentChunk.collection_name == collection_name
|
|
|
- )
|
|
|
- if limit is not None:
|
|
|
- query = query.limit(limit)
|
|
|
+ if PGVECTOR_PGCRYPTO:
|
|
|
+ stmt = select(
|
|
|
+ DocumentChunk.id,
|
|
|
+ pgcrypto_decrypt(
|
|
|
+ DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
|
|
|
+ ).label("text"),
|
|
|
+ pgcrypto_decrypt(
|
|
|
+ DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
|
|
+ ).label("vmetadata"),
|
|
|
+ ).where(DocumentChunk.collection_name == collection_name)
|
|
|
+ if limit is not None:
|
|
|
+ stmt = stmt.limit(limit)
|
|
|
+ results = self.session.execute(stmt).all()
|
|
|
+ ids = [[row.id for row in results]]
|
|
|
+ documents = [[row.text for row in results]]
|
|
|
+ metadatas = [[row.vmetadata for row in results]]
|
|
|
+ else:
|
|
|
|
|
|
- results = query.all()
|
|
|
+ query = self.session.query(DocumentChunk).filter(
|
|
|
+ DocumentChunk.collection_name == collection_name
|
|
|
+ )
|
|
|
+ if limit is not None:
|
|
|
+ query = query.limit(limit)
|
|
|
|
|
|
- if not results:
|
|
|
- return None
|
|
|
+ results = query.all()
|
|
|
|
|
|
- ids = [[result.id for result in results]]
|
|
|
- documents = [[result.text for result in results]]
|
|
|
- metadatas = [[result.vmetadata for result in results]]
|
|
|
+ if not results:
|
|
|
+ return None
|
|
|
+
|
|
|
+ ids = [[result.id for result in results]]
|
|
|
+ documents = [[result.text for result in results]]
|
|
|
+ metadatas = [[result.vmetadata for result in results]]
|
|
|
|
|
|
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
|
|
except Exception as e:
|
|
@@ -358,17 +502,33 @@ class PgvectorClient(VectorDBBase):
|
|
|
filter: Optional[Dict[str, Any]] = None,
|
|
|
) -> None:
|
|
|
try:
|
|
|
- query = self.session.query(DocumentChunk).filter(
|
|
|
- DocumentChunk.collection_name == collection_name
|
|
|
- )
|
|
|
- if ids:
|
|
|
- query = query.filter(DocumentChunk.id.in_(ids))
|
|
|
- if filter:
|
|
|
- for key, value in filter.items():
|
|
|
- query = query.filter(
|
|
|
- DocumentChunk.vmetadata[key].astext == str(value)
|
|
|
- )
|
|
|
- deleted = query.delete(synchronize_session=False)
|
|
|
+ if PGVECTOR_PGCRYPTO:
|
|
|
+ wheres = [DocumentChunk.collection_name == collection_name]
|
|
|
+ if ids:
|
|
|
+ wheres.append(DocumentChunk.id.in_(ids))
|
|
|
+ if filter:
|
|
|
+ for key, value in filter.items():
|
|
|
+ wheres.append(
|
|
|
+ pgcrypto_decrypt(
|
|
|
+ DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
|
|
+ )[key].astext
|
|
|
+ == str(value)
|
|
|
+ )
|
|
|
+ stmt = DocumentChunk.__table__.delete().where(*wheres)
|
|
|
+ result = self.session.execute(stmt)
|
|
|
+ deleted = result.rowcount
|
|
|
+ else:
|
|
|
+ query = self.session.query(DocumentChunk).filter(
|
|
|
+ DocumentChunk.collection_name == collection_name
|
|
|
+ )
|
|
|
+ if ids:
|
|
|
+ query = query.filter(DocumentChunk.id.in_(ids))
|
|
|
+ if filter:
|
|
|
+ for key, value in filter.items():
|
|
|
+ query = query.filter(
|
|
|
+ DocumentChunk.vmetadata[key].astext == str(value)
|
|
|
+ )
|
|
|
+ deleted = query.delete(synchronize_session=False)
|
|
|
self.session.commit()
|
|
|
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
|
|
|
except Exception as e:
|