浏览代码

enh/refac: pgvector pool support

Timothy Jaeryang Baek 2 月之前
父节点
当前提交
b4f04ff3a7
共有 2 个文件被更改,包括 62 次插入4 次删除
  1. 39 0
      backend/open_webui/config.py
  2. 23 4
      backend/open_webui/retrieval/vector/dbs/pgvector.py

+ 39 - 0
backend/open_webui/config.py

@@ -1886,6 +1886,45 @@ if PGVECTOR_PGCRYPTO and not PGVECTOR_PGCRYPTO_KEY:
         "PGVECTOR_PGCRYPTO is enabled but PGVECTOR_PGCRYPTO_KEY is not set. Please provide a valid key."
     )
 
+
+PGVECTOR_POOL_SIZE = os.environ.get("PGVECTOR_POOL_SIZE", None)
+
+if PGVECTOR_POOL_SIZE != None:
+    try:
+        PGVECTOR_POOL_SIZE = int(PGVECTOR_POOL_SIZE)
+    except Exception:
+        PGVECTOR_POOL_SIZE = None
+
+PGVECTOR_POOL_MAX_OVERFLOW = os.environ.get("PGVECTOR_POOL_MAX_OVERFLOW", 0)
+
+if PGVECTOR_POOL_MAX_OVERFLOW == "":
+    PGVECTOR_POOL_MAX_OVERFLOW = 0
+else:
+    try:
+        PGVECTOR_POOL_MAX_OVERFLOW = int(PGVECTOR_POOL_MAX_OVERFLOW)
+    except Exception:
+        PGVECTOR_POOL_MAX_OVERFLOW = 0
+
+PGVECTOR_POOL_TIMEOUT = os.environ.get("PGVECTOR_POOL_TIMEOUT", 30)
+
+if PGVECTOR_POOL_TIMEOUT == "":
+    PGVECTOR_POOL_TIMEOUT = 30
+else:
+    try:
+        PGVECTOR_POOL_TIMEOUT = int(PGVECTOR_POOL_TIMEOUT)
+    except Exception:
+        PGVECTOR_POOL_TIMEOUT = 30
+
+PGVECTOR_POOL_RECYCLE = os.environ.get("PGVECTOR_POOL_RECYCLE", 3600)
+
+if PGVECTOR_POOL_RECYCLE == "":
+    PGVECTOR_POOL_RECYCLE = 3600
+else:
+    try:
+        PGVECTOR_POOL_RECYCLE = int(PGVECTOR_POOL_RECYCLE)
+    except Exception:
+        PGVECTOR_POOL_RECYCLE = 3600
+
 # Pinecone
 PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
 PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)

+ 23 - 4
backend/open_webui/retrieval/vector/dbs/pgvector.py

@@ -18,7 +18,7 @@ from sqlalchemy import (
     values,
 )
 from sqlalchemy.sql import true
-from sqlalchemy.pool import NullPool
+from sqlalchemy.pool import NullPool, QueuePool
 
 from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
 from sqlalchemy.dialects.postgresql import JSONB, array
@@ -37,6 +37,10 @@ from open_webui.config import (
     PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
     PGVECTOR_PGCRYPTO,
     PGVECTOR_PGCRYPTO_KEY,
+    PGVECTOR_POOL_SIZE,
+    PGVECTOR_POOL_MAX_OVERFLOW,
+    PGVECTOR_POOL_TIMEOUT,
+    PGVECTOR_POOL_RECYCLE,
 )
 
 from open_webui.env import SRC_LOG_LEVELS
@@ -80,9 +84,24 @@ class PgvectorClient(VectorDBBase):
 
             self.session = Session
         else:
-            engine = create_engine(
-                PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
-            )
+            if isinstance(PGVECTOR_POOL_SIZE, int):
+                if PGVECTOR_POOL_SIZE > 0:
+                    engine = create_engine(
+                        PGVECTOR_DB_URL,
+                        pool_size=PGVECTOR_POOL_SIZE,
+                        max_overflow=PGVECTOR_POOL_MAX_OVERFLOW,
+                        pool_timeout=PGVECTOR_POOL_TIMEOUT,
+                        pool_recycle=PGVECTOR_POOL_RECYCLE,
+                        pool_pre_ping=True,
+                        poolclass=QueuePool,
+                    )
+                else:
+                    engine = create_engine(
+                        PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
+                    )
+            else:
+                engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True)
+
             SessionLocal = sessionmaker(
                 autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
             )