Browse Source

fix peewee and playhouse connections to retry

perf3ct 1 year ago
parent
commit
10fa887eab
2 changed files with 40 additions and 37 deletions
  1. 18 10
      backend/apps/webui/internal/db.py
  2. 22 27
      backend/apps/webui/internal/wrappers.py

+ 18 - 10
backend/apps/webui/internal/db.py

@@ -4,15 +4,13 @@ import json
 
 from peewee import *
 from peewee_migrate import Router
-from playhouse.db_url import connect
 
-from apps.webui.internal.wrappers import PeeweeConnectionState, register_peewee_databases
+from apps.webui.internal.wrappers import register_connection
 from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["DB"])
 
-
 class JSONField(TextField):
     def db_value(self, value):
         return json.dumps(value)
@@ -21,9 +19,6 @@ class JSONField(TextField):
         if value is not None:
             return json.loads(value)
 
-
-register_peewee_databases()
-
 # Check if the file exists
 if os.path.exists(f"{DATA_DIR}/ollama.db"):
     # Rename the file
@@ -32,13 +27,26 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"):
 else:
     pass
 
-DB = connect(DATABASE_URL)
-DB._state = PeeweeConnectionState()
-log.info(f"Connected to a {DB.__class__.__name__} database.")
+
+# The `register_connection` function encapsulates the logic for setting up 
+# the database connection based on the connection string, while `connect` 
+# is a Peewee-specific method to manage the connection state and avoid errors 
+# when a connection is already open.
+try:
+    DB = register_connection(DATABASE_URL)
+    log.info(f"Connected to a {DB.__class__.__name__} database.")
+except Exception as e:
+    log.error(f"Failed to initialize the database connection: {e}")
+    raise
+
 router = Router(
     DB,
     migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
     logger=log,
 )
 router.run()
-DB.connect(reuse_if_open=True)
+try:
+    DB.connect()
+except OperationalError as e:
+    log.info(f"Failed to connect to database again due to: {e}")
+    pass

+ 22 - 27
backend/apps/webui/internal/wrappers.py

@@ -1,18 +1,13 @@
 from contextvars import ContextVar
-
-from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError, _ConnectionState
-from playhouse.db_url import register_database
+from peewee import *
+from playhouse.db_url import connect
 from playhouse.pool import PooledPostgresqlDatabase
 from playhouse.shortcuts import ReconnectMixin
-from psycopg2 import OperationalError
-from psycopg2.errors import InterfaceError
-
 
 db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
 db_state = ContextVar("db_state", default=db_state_default.copy())
 
-
-class PeeweeConnectionState(_ConnectionState):
+class PeeweeConnectionState(object):
     def __init__(self, **kwargs):
         super().__setattr__("_state", db_state)
         super().__init__(**kwargs)
@@ -21,29 +16,29 @@ class PeeweeConnectionState(_ConnectionState):
         self._state.get()[name] = value
 
     def __getattr__(self, name):
-        return self._state.get()[name]
-
+        value = self._state.get()[name]
+        return value
 
-class CustomReconnectMixin(ReconnectMixin):
-    reconnect_errors = (
-        # default ReconnectMixin exceptions
-        *ReconnectMixin.reconnect_errors,
-        # psycopg2
-        (OperationalError, 'termin'),
-        (InterfaceError, 'closed'),
-        # peewee
-        (PeeWeeInterfaceError, 'closed'),
-    )
-
-
-class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
+class ReconnectingPostgresqlDatabase(ReconnectMixin, PostgresqlDatabase):
     pass
 
+class ReconnectingPooledPostgresqlDatabase(ReconnectMixin, PooledPostgresqlDatabase):
+    pass
 
-class ReconnectingPooledPostgresqlDatabase(CustomReconnectMixin, PooledPostgresqlDatabase):
+class ReconnectingSqliteDatabase(ReconnectMixin, SqliteDatabase):
     pass
 
 
-def register_peewee_databases():
-    register_database(ReconnectingPostgresqlDatabase, 'postgres', 'postgresql')
-    register_database(ReconnectingPooledPostgresqlDatabase, 'postgres+pool', 'postgresql+pool')
+def register_connection(db_url):
+    # Connect using the playhouse.db_url module, which supports multiple 
+    # database types, then wrap the connection in a ReconnectMixin to handle dropped connections
+    db = connect(db_url)
+    if isinstance(db, PostgresqlDatabase):
+        db = ReconnectingPostgresqlDatabase(db.database, **db.connect_params)
+    elif isinstance(db, PooledPostgresqlDatabase):
+        db = ReconnectingPooledPostgresqlDatabase(db.database, **db.connect_params)
+    elif isinstance(db, SqliteDatabase):
+        db = ReconnectingSqliteDatabase(db.database, **db.connect_params)
+    else:
+        raise ValueError('Unsupported database connection')
+    return db