wrappers.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. import logging
  2. import os
  3. from contextvars import ContextVar
  4. from open_webui.env import SRC_LOG_LEVELS
  5. from peewee import *
  6. from peewee import InterfaceError as PeeWeeInterfaceError
  7. from peewee import PostgresqlDatabase
  8. from playhouse.db_url import connect, parse
  9. from playhouse.shortcuts import ReconnectMixin
  10. from playhouse.sqlcipher_ext import SqlCipherDatabase
  11. log = logging.getLogger(__name__)
  12. log.setLevel(SRC_LOG_LEVELS["DB"])
  13. db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
  14. db_state = ContextVar("db_state", default=db_state_default.copy())
  15. class PeeweeConnectionState(object):
  16. def __init__(self, **kwargs):
  17. super().__setattr__("_state", db_state)
  18. super().__init__(**kwargs)
  19. def __setattr__(self, name, value):
  20. self._state.get()[name] = value
  21. def __getattr__(self, name):
  22. value = self._state.get()[name]
  23. return value
  24. class CustomReconnectMixin(ReconnectMixin):
  25. reconnect_errors = (
  26. # psycopg2
  27. (OperationalError, "termin"),
  28. (InterfaceError, "closed"),
  29. # peewee
  30. (PeeWeeInterfaceError, "closed"),
  31. )
  32. class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
  33. pass
  34. def register_connection(db_url):
  35. # Check if using SQLCipher protocol
  36. if db_url.startswith('sqlite+sqlcipher://'):
  37. database_password = os.environ.get("DATABASE_PASSWORD")
  38. if not database_password or database_password.strip() == "":
  39. raise ValueError("DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs")
  40. # Parse the database path from SQLCipher URL
  41. # Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
  42. db_path = db_url.replace('sqlite+sqlcipher://', '')
  43. if db_path.startswith('/'):
  44. db_path = db_path[1:] # Remove leading slash for relative paths
  45. # Use Peewee's native SqlCipherDatabase with encryption
  46. db = SqlCipherDatabase(db_path, passphrase=database_password)
  47. db.autoconnect = True
  48. db.reuse_if_open = True
  49. log.info("Connected to encrypted SQLite database using SQLCipher")
  50. else:
  51. # Standard database connection (existing logic)
  52. db = connect(db_url, unquote_user=True, unquote_password=True)
  53. if isinstance(db, PostgresqlDatabase):
  54. # Enable autoconnect for SQLite databases, managed by Peewee
  55. db.autoconnect = True
  56. db.reuse_if_open = True
  57. log.info("Connected to PostgreSQL database")
  58. # Get the connection details
  59. connection = parse(db_url, unquote_user=True, unquote_password=True)
  60. # Use our custom database class that supports reconnection
  61. db = ReconnectingPostgresqlDatabase(**connection)
  62. db.connect(reuse_if_open=True)
  63. elif isinstance(db, SqliteDatabase):
  64. # Enable autoconnect for SQLite databases, managed by Peewee
  65. db.autoconnect = True
  66. db.reuse_if_open = True
  67. log.info("Connected to SQLite database")
  68. else:
  69. raise ValueError("Unsupported database connection")
  70. return db