wrappers.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import logging
  2. import os
  3. from contextvars import ContextVar
  4. from open_webui.env import SRC_LOG_LEVELS
  5. from peewee import *
  6. from peewee import InterfaceError as PeeWeeInterfaceError
  7. from peewee import PostgresqlDatabase
  8. from playhouse.db_url import connect, parse
  9. from playhouse.shortcuts import ReconnectMixin
  10. log = logging.getLogger(__name__)
  11. log.setLevel(SRC_LOG_LEVELS["DB"])
  12. db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
  13. db_state = ContextVar("db_state", default=db_state_default.copy())
  14. class PeeweeConnectionState(object):
  15. def __init__(self, **kwargs):
  16. super().__setattr__("_state", db_state)
  17. super().__init__(**kwargs)
  18. def __setattr__(self, name, value):
  19. self._state.get()[name] = value
  20. def __getattr__(self, name):
  21. value = self._state.get()[name]
  22. return value
  23. class CustomReconnectMixin(ReconnectMixin):
  24. reconnect_errors = (
  25. # psycopg2
  26. (OperationalError, "termin"),
  27. (InterfaceError, "closed"),
  28. # peewee
  29. (PeeWeeInterfaceError, "closed"),
  30. )
  31. class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
  32. pass
  33. def register_connection(db_url):
  34. # Check if using SQLCipher protocol
  35. if db_url.startswith("sqlite+sqlcipher://"):
  36. database_password = os.environ.get("DATABASE_PASSWORD")
  37. if not database_password or database_password.strip() == "":
  38. raise ValueError(
  39. "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
  40. )
  41. from playhouse.sqlcipher_ext import SqlCipherDatabase
  42. # Parse the database path from SQLCipher URL
  43. # Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
  44. db_path = db_url.replace("sqlite+sqlcipher://", "")
  45. if db_path.startswith("/"):
  46. db_path = db_path[1:] # Remove leading slash for relative paths
  47. # Use Peewee's native SqlCipherDatabase with encryption
  48. db = SqlCipherDatabase(db_path, passphrase=database_password)
  49. db.autoconnect = True
  50. db.reuse_if_open = True
  51. log.info("Connected to encrypted SQLite database using SQLCipher")
  52. else:
  53. # Standard database connection (existing logic)
  54. db = connect(db_url, unquote_user=True, unquote_password=True)
  55. if isinstance(db, PostgresqlDatabase):
  56. # Enable autoconnect for SQLite databases, managed by Peewee
  57. db.autoconnect = True
  58. db.reuse_if_open = True
  59. log.info("Connected to PostgreSQL database")
  60. # Get the connection details
  61. connection = parse(db_url, unquote_user=True, unquote_password=True)
  62. # Use our custom database class that supports reconnection
  63. db = ReconnectingPostgresqlDatabase(**connection)
  64. db.connect(reuse_if_open=True)
  65. elif isinstance(db, SqliteDatabase):
  66. # Enable autoconnect for SQLite databases, managed by Peewee
  67. db.autoconnect = True
  68. db.reuse_if_open = True
  69. log.info("Connected to SQLite database")
  70. else:
  71. raise ValueError("Unsupported database connection")
  72. return db