Browse Source

feat: add Redis Sentinel failover support for high availability

- Implement SentinelRedisProxy class with automatic master discovery
- Add retry logic for handling connection failures and read-only errors
- Support both async and sync Redis operations with Sentinel
- Ensure backward compatibility with existing Redis configurations
- Provide seamless failover during master node outages

This enhancement significantly improves system reliability by eliminating
single points of failure in Redis deployments and ensuring continuous
service availability during infrastructure issues.

Signed-off-by: Sihyeon Jang <sihyeon.jang@navercorp.com>
Sihyeon Jang 3 months ago
parent
commit
423d0923d9
1 changed files with 74 additions and 4 deletions
  1. 74 4
      backend/open_webui/utils/redis.py

+ 74 - 4
backend/open_webui/utils/redis.py

@@ -1,6 +1,68 @@
-import socketio
+import inspect
 from urllib.parse import urlparse
-from typing import Optional
+
+import redis
+
+
+MAX_RETRY_COUNT = 2
+
+class SentinelRedisProxy:
+    def __init__(self, sentinel, service, *, async_mode: bool = True, **kw):
+        self._sentinel = sentinel
+        self._service = service
+        self._kw = kw
+        self._async_mode = async_mode
+
+    def _master(self):
+        return self._sentinel.master_for(self._service, **self._kw)
+
+    def __getattr__(self, item):
+        master = self._master()
+        orig_attr = getattr(master, item)
+
+        if not callable(orig_attr):
+            return orig_attr
+
+        FACTORY_METHODS = {"pipeline", "pubsub", "monitor", "client", "transaction"}
+        if item in FACTORY_METHODS:
+            return orig_attr
+
+        if self._async_mode:
+
+            async def _wrapped(*args, **kwargs):
+                for i in range(MAX_RETRY_COUNT):
+                    try:
+                        method = getattr(self._master(), item)
+                        result = method(*args, **kwargs)
+                        if inspect.iscoroutine(result):
+                            return await result
+                        return result
+                    except (
+                        redis.exceptions.ConnectionError,
+                        redis.exceptions.ReadOnlyError,
+                    ) as e:
+                        if i < MAX_RETRY_COUNT - 1:
+                            continue
+                        raise e from e
+
+            return _wrapped
+
+        else:
+
+            def _wrapped(*args, **kwargs):
+                for i in range(MAX_RETRY_COUNT):
+                    try:
+                        method = getattr(self._master(), item)
+                        return method(*args, **kwargs)
+                    except (
+                        redis.exceptions.ConnectionError,
+                        redis.exceptions.ReadOnlyError,
+                    ) as e:
+                        if i < MAX_RETRY_COUNT - 1:
+                            continue
+                        raise e from e
+
+            return _wrapped
 
 
 def parse_redis_service_url(redis_url):
@@ -34,7 +96,11 @@ def get_redis_connection(
                 password=redis_config["password"],
                 decode_responses=decode_responses,
             )
-            return sentinel.master_for(redis_config["service"])
+            return SentinelRedisProxy(
+                sentinel,
+                redis_config["service"],
+                async_mode=async_mode,
+            )
         elif redis_url:
             return redis.from_url(redis_url, decode_responses=decode_responses)
         else:
@@ -52,7 +118,11 @@ def get_redis_connection(
                 password=redis_config["password"],
                 decode_responses=decode_responses,
             )
-            return sentinel.master_for(redis_config["service"])
+            return SentinelRedisProxy(
+                sentinel,
+                redis_config["service"],
+                async_mode=async_mode,
+            )
         elif redis_url:
             return redis.Redis.from_url(redis_url, decode_responses=decode_responses)
         else: