Jelajahi Sumber

chore: format

Timothy Jaeryang Baek 2 bulan lalu
induk
melakukan
77189664c2

+ 12 - 9
backend/open_webui/internal/db.py

@@ -82,29 +82,32 @@ handle_peewee_migration(DATABASE_URL)
 SQLALCHEMY_DATABASE_URL = DATABASE_URL
 
 # Handle SQLCipher URLs
-if SQLALCHEMY_DATABASE_URL.startswith('sqlite+sqlcipher://'):
+if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
     database_password = os.environ.get("DATABASE_PASSWORD")
     if not database_password or database_password.strip() == "":
-        raise ValueError("DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs")
-    
+        raise ValueError(
+            "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
+        )
+
     # Extract database path from SQLCipher URL
-    db_path = SQLALCHEMY_DATABASE_URL.replace('sqlite+sqlcipher://', '')
-    if db_path.startswith('/'):
+    db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "")
+    if db_path.startswith("/"):
         db_path = db_path[1:]  # Remove leading slash for relative paths
-    
+
     # Create a custom creator function that uses sqlcipher3
     def create_sqlcipher_connection():
         import sqlcipher3
+
         conn = sqlcipher3.connect(db_path, check_same_thread=False)
         conn.execute(f"PRAGMA key = '{database_password}'")
         return conn
-    
+
     engine = create_engine(
         "sqlite://",  # Dummy URL since we're using creator
         creator=create_sqlcipher_connection,
-        echo=False
+        echo=False,
     )
-    
+
     log.info("Connected to encrypted SQLite database using SQLCipher")
 
 elif "sqlite" in SQLALCHEMY_DATABASE_URL:

+ 9 - 7
backend/open_webui/internal/wrappers.py

@@ -46,23 +46,25 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
 
 def register_connection(db_url):
     # Check if using SQLCipher protocol
-    if db_url.startswith('sqlite+sqlcipher://'):
+    if db_url.startswith("sqlite+sqlcipher://"):
         database_password = os.environ.get("DATABASE_PASSWORD")
         if not database_password or database_password.strip() == "":
-            raise ValueError("DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs")
-        
+            raise ValueError(
+                "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
+            )
+
         # Parse the database path from SQLCipher URL
         # Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
-        db_path = db_url.replace('sqlite+sqlcipher://', '')
-        if db_path.startswith('/'):
+        db_path = db_url.replace("sqlite+sqlcipher://", "")
+        if db_path.startswith("/"):
             db_path = db_path[1:]  # Remove leading slash for relative paths
-        
+
         # Use Peewee's native SqlCipherDatabase with encryption
         db = SqlCipherDatabase(db_path, passphrase=database_password)
         db.autoconnect = True
         db.reuse_if_open = True
         log.info("Connected to encrypted SQLite database using SQLCipher")
-        
+
     else:
         # Standard database connection (existing logic)
         db = connect(db_url, unquote_user=True, unquote_password=True)

+ 11 - 8
backend/open_webui/migrations/env.py

@@ -63,26 +63,29 @@ def run_migrations_online() -> None:
 
     """
     # Handle SQLCipher URLs
-    if DB_URL and DB_URL.startswith('sqlite+sqlcipher://'):
+    if DB_URL and DB_URL.startswith("sqlite+sqlcipher://"):
         if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "":
-            raise ValueError("DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs")
-        
+            raise ValueError(
+                "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
+            )
+
         # Extract database path from SQLCipher URL
-        db_path = DB_URL.replace('sqlite+sqlcipher://', '')
-        if db_path.startswith('/'):
+        db_path = DB_URL.replace("sqlite+sqlcipher://", "")
+        if db_path.startswith("/"):
             db_path = db_path[1:]  # Remove leading slash for relative paths
-        
+
         # Create a custom creator function that uses sqlcipher3
         def create_sqlcipher_connection():
             import sqlcipher3
+
             conn = sqlcipher3.connect(db_path, check_same_thread=False)
             conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'")
             return conn
-        
+
         connectable = create_engine(
             "sqlite://",  # Dummy URL since we're using creator
             creator=create_sqlcipher_connection,
-            echo=False
+            echo=False,
         )
     else:
         # Standard database connection (existing logic)

+ 4 - 4
backend/open_webui/retrieval/vector/dbs/pgvector.py

@@ -421,7 +421,7 @@ class PgvectorClient(VectorDBBase):
                 documents[qid].append(row.text)
                 metadatas[qid].append(row.vmetadata)
 
-            self.session.rollback() # read-only transaction
+            self.session.rollback()  # read-only transaction
             return SearchResult(
                 ids=ids, distances=distances, documents=documents, metadatas=metadatas
             )
@@ -479,7 +479,7 @@ class PgvectorClient(VectorDBBase):
             documents = [[result.text for result in results]]
             metadatas = [[result.vmetadata for result in results]]
 
-            self.session.rollback() # read-only transaction
+            self.session.rollback()  # read-only transaction
             return GetResult(
                 ids=ids,
                 documents=documents,
@@ -527,7 +527,7 @@ class PgvectorClient(VectorDBBase):
                 documents = [[result.text for result in results]]
                 metadatas = [[result.vmetadata for result in results]]
 
-            self.session.rollback() # read-only transaction
+            self.session.rollback()  # read-only transaction
             return GetResult(ids=ids, documents=documents, metadatas=metadatas)
         except Exception as e:
             self.session.rollback()
@@ -598,7 +598,7 @@ class PgvectorClient(VectorDBBase):
                 .first()
                 is not None
             )
-            self.session.rollback() # read-only transaction
+            self.session.rollback()  # read-only transaction
             return exists
         except Exception as e:
             self.session.rollback()

+ 5 - 1
backend/open_webui/retrieval/vector/dbs/qdrant.py

@@ -60,7 +60,11 @@ class QdrantClient(VectorDBBase):
                 timeout=self.QDRANT_TIMEOUT,
             )
         else:
-            self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY, timeout=QDRANT_TIMEOUT,)
+            self.client = Qclient(
+                url=self.QDRANT_URI,
+                api_key=self.QDRANT_API_KEY,
+                timeout=QDRANT_TIMEOUT,
+            )
 
     def _result_to_get_result(self, points) -> GetResult:
         ids = []

+ 5 - 1
backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py

@@ -76,7 +76,11 @@ class QdrantClient(VectorDBBase):
                 timeout=self.QDRANT_TIMEOUT,
             )
             if self.PREFER_GRPC
-            else Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY, timeout=self.QDRANT_TIMEOUT,)
+            else Qclient(
+                url=self.QDRANT_URI,
+                api_key=self.QDRANT_API_KEY,
+                timeout=self.QDRANT_TIMEOUT,
+            )
         )
 
         # Main collection types for multi-tenancy