wrappers.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import logging
  2. import os
  3. from contextvars import ContextVar
  4. from open_webui.env import SRC_LOG_LEVELS
  5. from peewee import *
  6. from peewee import InterfaceError as PeeWeeInterfaceError
  7. from peewee import PostgresqlDatabase
  8. from playhouse.db_url import connect, parse
  9. from playhouse.shortcuts import ReconnectMixin
  10. from playhouse.sqlcipher_ext import SqlCipherDatabase
  11. log = logging.getLogger(__name__)
  12. log.setLevel(SRC_LOG_LEVELS["DB"])
  13. db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
  14. db_state = ContextVar("db_state", default=db_state_default.copy())
  15. class PeeweeConnectionState(object):
  16. def __init__(self, **kwargs):
  17. super().__setattr__("_state", db_state)
  18. super().__init__(**kwargs)
  19. def __setattr__(self, name, value):
  20. self._state.get()[name] = value
  21. def __getattr__(self, name):
  22. value = self._state.get()[name]
  23. return value
  24. class CustomReconnectMixin(ReconnectMixin):
  25. reconnect_errors = (
  26. # psycopg2
  27. (OperationalError, "termin"),
  28. (InterfaceError, "closed"),
  29. # peewee
  30. (PeeWeeInterfaceError, "closed"),
  31. )
  32. class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
  33. pass
  34. def register_connection(db_url):
  35. # Check if using SQLCipher protocol
  36. if db_url.startswith("sqlite+sqlcipher://"):
  37. database_password = os.environ.get("DATABASE_PASSWORD")
  38. if not database_password or database_password.strip() == "":
  39. raise ValueError(
  40. "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
  41. )
  42. # Parse the database path from SQLCipher URL
  43. # Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
  44. db_path = db_url.replace("sqlite+sqlcipher://", "")
  45. if db_path.startswith("/"):
  46. db_path = db_path[1:] # Remove leading slash for relative paths
  47. # Use Peewee's native SqlCipherDatabase with encryption
  48. db = SqlCipherDatabase(db_path, passphrase=database_password)
  49. db.autoconnect = True
  50. db.reuse_if_open = True
  51. log.info("Connected to encrypted SQLite database using SQLCipher")
  52. else:
  53. # Standard database connection (existing logic)
  54. db = connect(db_url, unquote_user=True, unquote_password=True)
  55. if isinstance(db, PostgresqlDatabase):
  56. # Enable autoconnect for SQLite databases, managed by Peewee
  57. db.autoconnect = True
  58. db.reuse_if_open = True
  59. log.info("Connected to PostgreSQL database")
  60. # Get the connection details
  61. connection = parse(db_url, unquote_user=True, unquote_password=True)
  62. # Use our custom database class that supports reconnection
  63. db = ReconnectingPostgresqlDatabase(**connection)
  64. db.connect(reuse_if_open=True)
  65. elif isinstance(db, SqliteDatabase):
  66. # Enable autoconnect for SQLite databases, managed by Peewee
  67. db.autoconnect = True
  68. db.reuse_if_open = True
  69. log.info("Connected to SQLite database")
  70. else:
  71. raise ValueError("Unsupported database connection")
  72. return db