Browse Source

Merge pull request #15863 from tcx4c70/feat/sqlite-wal

perf(db): Improve performance of db, especially sqlite
Tim Jaeryang Baek 1 month ago
parent
commit
6a109e972e

+ 13 - 0
backend/open_webui/env.py

@@ -339,6 +339,19 @@ else:
     except Exception:
         DATABASE_POOL_RECYCLE = 3600
 
+DATABASE_ENABLE_SQLITE_WAL = (os.environ.get("DATABASE_ENABLE_SQLITE_WAL", "False").lower() == "true")
+
+DATABASE_DEDUPLICATE_INTERVAL = (
+    os.environ.get("DATABASE_DEDUPLICATE_INTERVAL", 0.)
+)
+if DATABASE_DEDUPLICATE_INTERVAL == "":
+    DATABASE_DEDUPLICATE_INTERVAL = 0.0
+else:
+    try:
+        DATABASE_DEDUPLICATE_INTERVAL = float(DATABASE_DEDUPLICATE_INTERVAL)
+    except Exception:
+        DATABASE_DEDUPLICATE_INTERVAL = 0.0
+
 RESET_CONFIG_ON_START = (
     os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
 )

+ 12 - 1
backend/open_webui/internal/db.py

@@ -14,9 +14,10 @@ from open_webui.env import (
     DATABASE_POOL_RECYCLE,
     DATABASE_POOL_SIZE,
     DATABASE_POOL_TIMEOUT,
+    DATABASE_ENABLE_SQLITE_WAL,
 )
 from peewee_migrate import Router
-from sqlalchemy import Dialect, create_engine, MetaData, types
+from sqlalchemy import Dialect, create_engine, MetaData, event, types
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import scoped_session, sessionmaker
 from sqlalchemy.pool import QueuePool, NullPool
@@ -114,6 +115,16 @@ elif "sqlite" in SQLALCHEMY_DATABASE_URL:
     engine = create_engine(
         SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
     )
+
+    def on_connect(dbapi_connection, connection_record):
+        cursor = dbapi_connection.cursor()
+        if DATABASE_ENABLE_SQLITE_WAL:
+            cursor.execute("PRAGMA journal_mode=WAL")
+        else:
+            cursor.execute("PRAGMA journal_mode=DELETE")
+        cursor.close()
+
+    event.listen(engine, "connect", on_connect)
 else:
     if isinstance(DATABASE_POOL_SIZE, int):
         if DATABASE_POOL_SIZE > 0:

+ 3 - 0
backend/open_webui/models/users.py

@@ -4,8 +4,10 @@ from typing import Optional
 from open_webui.internal.db import Base, JSONField, get_db
 
 
+from open_webui.env import DATABASE_DEDUPLICATE_INTERVAL
 from open_webui.models.chats import Chats
 from open_webui.models.groups import Groups
+from open_webui.utils.misc import deduplicate
 
 
 from pydantic import BaseModel, ConfigDict
@@ -311,6 +313,7 @@ class UsersTable:
         except Exception:
             return None
 
+    @deduplicate(DATABASE_DEDUPLICATE_INTERVAL)
     def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
         try:
             with get_db() as db:

+ 41 - 0
backend/open_webui/utils/misc.py

@@ -1,5 +1,6 @@
 import hashlib
 import re
+import threading
 import time
 import uuid
 import logging
@@ -478,3 +479,43 @@ def convert_logit_bias_input_to_json(user_input):
         bias = 100 if bias > 100 else -100 if bias < -100 else bias
         logit_bias_json[token] = bias
     return json.dumps(logit_bias_json)
+
+
+def freeze(value):
+    """
+    Freeze a value to make it hashable.
+    """
+    if isinstance(value, dict):
+        return frozenset((k, freeze(v)) for k, v in value.items())
+    elif isinstance(value, list):
+        return tuple(freeze(v) for v in value)
+    return value
+
+
+def deduplicate(interval: float = 10.0):
+    """
+    Decorator to prevent a function from being called more than once within a specified duration.
+    If the function is called again within the duration, it returns None. To avoid returning
+    different types, the return type of the function should be Optional[T].
+
+    :param interval: Duration in seconds to wait before allowing the function to be called again.
+    """
+
+    def decorator(func):
+        last_calls = {}
+        lock = threading.Lock()
+
+        def wrapper(*args, **kwargs):
+            key = (args, freeze(kwargs))
+            now = time.time()
+            if now - last_calls.get(key, 0) < interval:
+                return None
+            with lock:
+                if now - last_calls.get(key, 0) < interval:
+                    return None
+                last_calls[key] = now
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator