# db.py
import os
import json
import logging
from contextlib import contextmanager
from typing import Any, Optional

from open_webui.internal.wrappers import register_connection
from open_webui.env import (
    OPEN_WEBUI_DIR,
    DATABASE_URL,
    DATABASE_SCHEMA,
    SRC_LOG_LEVELS,
    DATABASE_POOL_MAX_OVERFLOW,
    DATABASE_POOL_RECYCLE,
    DATABASE_POOL_SIZE,
    DATABASE_POOL_TIMEOUT,
    DATABASE_ENABLE_SQLITE_WAL,
)

from peewee_migrate import Router

from sqlalchemy import Dialect, create_engine, MetaData, event, types
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.pool import QueuePool, NullPool
from sqlalchemy.sql.type_api import _T
from typing_extensions import Self

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])

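
# JSONField stores arbitrary Python values as JSON text in a plain TEXT column.
# process_bind_param/process_result_value cover the SQLAlchemy path, while
# db_value/python_value provide the same JSON conversion under the method names
# that peewee fields use.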
class JSONField(types.TypeDecorator):
    impl = types.Text
    cache_ok = True

    def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Any:
        return json.dumps(value)

    def process_result_value(self, value: Optional[_T], dialect: Dialect) -> Any:
        if value is not None:
            return json.loads(value)

    def copy(self, **kw: Any) -> Self:
        return JSONField(self.impl.length)

    def db_value(self, value):
        return json.dumps(value)

    def python_value(self, value):
        if value is not None:
            return json.loads(value)

# Workaround to handle the peewee migration
# This is required to ensure the peewee migration is handled before the alembic migration
def handle_peewee_migration(DATABASE_URL):
    db = None
    try:
        # Replace the postgresql:// with postgres:// to handle the peewee migration
        db = register_connection(DATABASE_URL.replace("postgresql://", "postgres://"))
        migrate_dir = OPEN_WEBUI_DIR / "internal" / "migrations"
        router = Router(db, logger=log, migrate_dir=migrate_dir)
        router.run()
        db.close()
    except Exception as e:
        log.error(f"Failed to initialize the database connection: {e}")
        log.warning(
            "Hint: If your database password contains special characters, you may need to URL-encode it."
        )
        raise
    finally:
        # Properly close the database connection
        if db and not db.is_closed():
            db.close()

        # Assert that the database connection has been closed
        assert db is None or db.is_closed(), "Database connection is still open."


handle_peewee_migration(DATABASE_URL)

SQLALCHEMY_DATABASE_URL = DATABASE_URL

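# Engine selection: an SQLCipher-encrypted SQLite file, a plain SQLite file,
# or any other database, with optional connection pooling configured below.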
# Handle SQLCipher URLs
if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
    database_password = os.environ.get("DATABASE_PASSWORD")
    if not database_password or database_password.strip() == "":
        raise ValueError(
            "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
        )

    # Extract database path from SQLCipher URL
    db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "")
    if db_path.startswith("/"):
        db_path = db_path[1:]  # Remove leading slash for relative paths

    # Create a custom creator function that uses sqlcipher3
    def create_sqlcipher_connection():
        import sqlcipher3

        conn = sqlcipher3.connect(db_path, check_same_thread=False)
        conn.execute(f"PRAGMA key = '{database_password}'")
        return conn

    engine = create_engine(
        "sqlite://",  # Dummy URL since we're using creator
        creator=create_sqlcipher_connection,
        echo=False,
    )
    log.info("Connected to encrypted SQLite database using SQLCipher")
elif "sqlite" in SQLALCHEMY_DATABASE_URL:
    engine = create_engine(
        SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
    )

    def on_connect(dbapi_connection, connection_record):
        cursor = dbapi_connection.cursor()
        if DATABASE_ENABLE_SQLITE_WAL:
            cursor.execute("PRAGMA journal_mode=WAL")
        else:
            cursor.execute("PRAGMA journal_mode=DELETE")
        cursor.close()

    event.listen(engine, "connect", on_connect)
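# Non-SQLite databases: a DATABASE_POOL_SIZE greater than 0 enables a QueuePool
# with the configured overflow/timeout/recycle settings, 0 disables pooling via
# NullPool, and a non-integer value falls back to SQLAlchemy's default pooling.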
else:
    if isinstance(DATABASE_POOL_SIZE, int):
        if DATABASE_POOL_SIZE > 0:
            engine = create_engine(
                SQLALCHEMY_DATABASE_URL,
                pool_size=DATABASE_POOL_SIZE,
                max_overflow=DATABASE_POOL_MAX_OVERFLOW,
                pool_timeout=DATABASE_POOL_TIMEOUT,
                pool_recycle=DATABASE_POOL_RECYCLE,
                pool_pre_ping=True,
                poolclass=QueuePool,
            )
        else:
            engine = create_engine(
                SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
            )
    else:
        engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)

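# expire_on_commit=False keeps ORM objects readable after commit, and Session
# is a thread-local scoped_session wrapper around the SessionLocal factory.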
SessionLocal = sessionmaker(
    autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)
metadata_obj = MetaData(schema=DATABASE_SCHEMA)
Base = declarative_base(metadata=metadata_obj)
Session = scoped_session(SessionLocal)


def get_session():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


get_db = contextmanager(get_session)
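
# Example usage (illustrative sketch; assumes this module is importable as
# open_webui.internal.db):
#
#     from sqlalchemy import text
#     from open_webui.internal.db import get_db
#
#     with get_db() as db:
#         value = db.execute(text("SELECT 1")).scalar()
#
# The session is closed automatically when the `with` block exits. The plain
# generator `get_session` can likewise be used as a dependency that yields a
# session and closes it afterwards (e.g. with FastAPI's Depends).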