Ver Fonte

Merge pull request #11148 from Ithanil/redis_sentinel

feat: Support for Redis Sentinel
Timothy Jaeryang Baek há 1 mês atrás
pai
commit
7fce3e2a9c

+ 5 - 3
backend/open_webui/config.py

@@ -19,6 +19,8 @@ from open_webui.env import (
     DATABASE_URL,
     ENV,
     REDIS_URL,
+    REDIS_SENTINEL_HOSTS,
+    REDIS_SENTINEL_PORT,
     FRONTEND_BUILD_DIR,
     OFFLINE_MODE,
     OPEN_WEBUI_DIR,
@@ -28,7 +30,7 @@ from open_webui.env import (
     log,
 )
 from open_webui.internal.db import Base, get_db
-
+from open_webui.utils.redis import get_redis_connection
 
 class EndpointFilter(logging.Filter):
     def filter(self, record: logging.LogRecord) -> bool:
@@ -252,11 +254,11 @@ class AppConfig:
     _state: dict[str, PersistentConfig]
     _redis: Optional[redis.Redis] = None
 
-    def __init__(self, redis_url: Optional[str] = None):
+    def __init__(self, redis_url: Optional[str] = None, redis_sentinels: Optional[list] = []):
         super().__setattr__("_state", {})
         if redis_url:
             super().__setattr__(
-                "_redis", redis.Redis.from_url(redis_url, decode_responses=True)
+                "_redis", get_redis_connection(redis_url, redis_sentinels, decode_responses=True)
             )
 
     def __setattr__(self, key, value):

+ 6 - 0
backend/open_webui/env.py

@@ -323,6 +323,8 @@ ENABLE_REALTIME_CHAT_SAVE = (
 ####################################
 
 REDIS_URL = os.environ.get("REDIS_URL", "")
+REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "")
+REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379")
 
 ####################################
 # WEBUI_AUTH (Required for security)
@@ -379,6 +381,10 @@ WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
 WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
 WEBSOCKET_REDIS_LOCK_TIMEOUT = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", 60)
 
+WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "")
+
+WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379")
+
 AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
 
 if AIOHTTP_CLIENT_TIMEOUT == "":

+ 6 - 1
backend/open_webui/main.py

@@ -317,6 +317,8 @@ from open_webui.env import (
     AUDIT_LOG_LEVEL,
     CHANGELOG,
     REDIS_URL,
+    REDIS_SENTINEL_HOSTS,
+    REDIS_SENTINEL_PORT,
     GLOBAL_LOG_LEVEL,
     MAX_BODY_LOG_SIZE,
     SAFE_MODE,
@@ -360,6 +362,9 @@ from open_webui.utils.security_headers import SecurityHeadersMiddleware
 
 from open_webui.tasks import stop_task, list_tasks  # Import from tasks.py
 
+from open_webui.utils.redis import get_sentinels_from_env
+
+
 if SAFE_MODE:
     print("SAFE MODE ENABLED")
     Functions.deactivate_all_functions()
@@ -423,7 +428,7 @@ app = FastAPI(
 
 oauth_manager = OAuthManager(app)
 
-app.state.config = AppConfig(redis_url=REDIS_URL)
+app.state.config = AppConfig(redis_url=REDIS_URL, redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT))
 
 app.state.WEBUI_NAME = WEBUI_NAME
 app.state.LICENSE_METADATA = None

+ 15 - 4
backend/open_webui/socket/main.py

@@ -3,16 +3,20 @@ import socketio
 import logging
 import sys
 import time
+from redis import asyncio as aioredis
 
 from open_webui.models.users import Users, UserNameResponse
 from open_webui.models.channels import Channels
 from open_webui.models.chats import Chats
+from open_webui.utils.redis import parse_redis_sentinel_url, get_sentinels_from_env, AsyncRedisSentinelManager
 
 from open_webui.env import (
     ENABLE_WEBSOCKET_SUPPORT,
     WEBSOCKET_MANAGER,
     WEBSOCKET_REDIS_URL,
     WEBSOCKET_REDIS_LOCK_TIMEOUT,
+    WEBSOCKET_SENTINEL_PORT,
+    WEBSOCKET_SENTINEL_HOSTS,
 )
 from open_webui.utils.auth import decode_token
 from open_webui.socket.utils import RedisDict, RedisLock
@@ -29,7 +33,12 @@ log.setLevel(SRC_LOG_LEVELS["SOCKET"])
 
 
 if WEBSOCKET_MANAGER == "redis":
-    mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
+    if WEBSOCKET_SENTINEL_HOSTS:
+        redis_config = parse_redis_sentinel_url(WEBSOCKET_REDIS_URL)
+        mgr = AsyncRedisSentinelManager(WEBSOCKET_SENTINEL_HOSTS.split(','), sentinel_port=int(WEBSOCKET_SENTINEL_PORT), redis_port=redis_config["port"],
+                                        service=redis_config["service"], db=redis_config["db"], username=redis_config["username"], password=redis_config["password"])
+    else:
+        mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
     sio = socketio.AsyncServer(
         cors_allowed_origins=[],
         async_mode="asgi",
@@ -55,14 +64,16 @@ TIMEOUT_DURATION = 3
 
 if WEBSOCKET_MANAGER == "redis":
     log.debug("Using Redis to manage websockets.")
-    SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
-    USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
-    USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
+    redis_sentinels=get_sentinels_from_env(WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT)
+    SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL, redis_sentinels=redis_sentinels)
+    USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL, redis_sentinels=redis_sentinels)
+    USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL, redis_sentinels=redis_sentinels)
 
     clean_up_lock = RedisLock(
         redis_url=WEBSOCKET_REDIS_URL,
         lock_name="usage_cleanup_lock",
         timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
+        redis_sentinels=redis_sentinels,
     )
     aquire_func = clean_up_lock.aquire_lock
     renew_func = clean_up_lock.renew_lock

+ 5 - 6
backend/open_webui/socket/utils.py

@@ -1,15 +1,14 @@
 import json
-import redis
 import uuid
-
+from open_webui.utils.redis import get_redis_connection
 
 class RedisLock:
-    def __init__(self, redis_url, lock_name, timeout_secs):
+    def __init__(self, redis_url, lock_name, timeout_secs, redis_sentinels=[]):
         self.lock_name = lock_name
         self.lock_id = str(uuid.uuid4())
         self.timeout_secs = timeout_secs
         self.lock_obtained = False
-        self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
+        self.redis = get_redis_connection(redis_url, redis_sentinels, decode_responses=True)
 
     def aquire_lock(self):
         # nx=True will only set this key if it _hasn't_ already been set
@@ -31,9 +30,9 @@ class RedisLock:
 
 
 class RedisDict:
-    def __init__(self, name, redis_url):
+    def __init__(self, name, redis_url, redis_sentinels=[]):
         self.name = name
-        self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
+        self.redis = get_redis_connection(redis_url, redis_sentinels, decode_responses=True)
 
     def __setitem__(self, key, value):
         serialized_value = json.dumps(value)

+ 91 - 0
backend/open_webui/utils/redis.py

@@ -0,0 +1,91 @@
+import socketio
+import redis
+from redis import asyncio as aioredis
+from urllib.parse import urlparse
+
+def parse_redis_sentinel_url(redis_url):
+    parsed_url = urlparse(redis_url)
+    if parsed_url.scheme != "redis":
+        raise ValueError("Invalid Redis URL scheme. Must be 'redis'.")
+
+    return {
+        "username": parsed_url.username or None,
+        "password": parsed_url.password or None,
+        "service": parsed_url.hostname or 'mymaster',
+        "port": parsed_url.port or 6379,
+        "db": int(parsed_url.path.lstrip("/") or 0),
+    }
+
+def get_redis_connection(redis_url, redis_sentinels, decode_responses=True):
+    if redis_sentinels:
+        redis_config = parse_redis_sentinel_url(redis_url)
+        sentinel = redis.sentinel.Sentinel(
+            redis_sentinels,
+            port=redis_config['port'],
+            db=redis_config['db'],
+            username=redis_config['username'],
+            password=redis_config['password'],
+            decode_responses=decode_responses
+        )
+
+        # Get a master connection from Sentinel
+        return sentinel.master_for(redis_config['service'])
+    else:
+        # Standard Redis connection
+        return redis.Redis.from_url(redis_url, decode_responses=decode_responses)
+
+def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):
+    if sentinel_hosts_env:
+        sentinel_hosts=sentinel_hosts_env.split(',')
+        sentinel_port=int(sentinel_port_env)
+        return [(host, sentinel_port) for host in sentinel_hosts]
+    return []
+
+class AsyncRedisSentinelManager(socketio.AsyncRedisManager):
+    def __init__(self, sentinel_hosts, sentinel_port=26379, redis_port=6379, service="mymaster", db=0,
+                 username=None, password=None, channel='socketio', write_only=False, logger=None, redis_options=None):
+        """
+        Initialize the Redis Sentinel Manager.
+        This implementation mostly replicates the __init__ of AsyncRedisManager and
+        overrides _redis_connect() with a version that uses Redis Sentinel
+
+        :param sentinel_hosts: List of Sentinel hosts
+        :param sentinel_port: Sentinel Port
+        :param redis_port: Redis Port (currently unsupported by aioredis!)
+        :param service: Master service name in Sentinel
+        :param db: Redis database to use
+        :param username: Redis username (if any) (currently unsupported by aioredis!)
+        :param password: Redis password (if any)
+        :param channel: The channel name on which the server sends and receives
+                        notifications. Must be the same in all the servers.
+        :param write_only: If set to ``True``, only initialize to emit events. The
+                           default of ``False`` initializes the class for emitting
+                           and receiving.
+        :param redis_options: additional keyword arguments to be passed to
+                              ``aioredis.from_url()``.
+        """
+        self._sentinels = [(host, sentinel_port) for host in sentinel_hosts]
+        self._redis_port=redis_port
+        self._service = service
+        self._db = db
+        self._username = username
+        self._password = password
+        self._channel = channel
+        self.redis_options = redis_options or {}
+
+        # connect and call grandparent constructor
+        self._redis_connect()
+        super(socketio.AsyncRedisManager, self).__init__(channel=channel, write_only=write_only, logger=logger)
+
+    def _redis_connect(self):
+        """Establish connections to Redis through Sentinel."""
+        sentinel = aioredis.sentinel.Sentinel(
+            self._sentinels,
+            port=self._redis_port,
+            db=self._db,
+            password=self._password,
+            **self.redis_options
+        )
+
+        self.redis = sentinel.master_for(self._service)
+        self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True)