Bläddra i källkod

first part of adding Redis Sentinel support

Jan Kessler 2 månader sedan
förälder
incheckning
3b357746d8
2 ändrade filer med 72 tillägg och 3 borttagningar
  1. 58 2
      backend/open_webui/socket/main.py
  2. 14 1
      backend/open_webui/socket/utils.py

+ 58 - 2
backend/open_webui/socket/main.py

@@ -3,6 +3,7 @@ import socketio
 import logging
 import sys
 import time
+from redis.sentinel import Sentinel
 
 from open_webui.models.users import Users, UserNameResponse
 from open_webui.models.channels import Channels
@@ -13,15 +14,65 @@ from open_webui.env import (
     WEBSOCKET_MANAGER,
     WEBSOCKET_REDIS_URL,
     WEBSOCKET_REDIS_LOCK_TIMEOUT,
+    WEBSOCKET_SENTINEL_PORT,
+    WEBSOCKET_SENTINEL_HOSTS,
 )
 from open_webui.utils.auth import decode_token
-from open_webui.socket.utils import RedisDict, RedisLock
+from open_webui.socket.utils import RedisDict, RedisLock, parse_redis_sentinel_url
 
 from open_webui.env import (
     GLOBAL_LOG_LEVEL,
     SRC_LOG_LEVELS,
 )
 
+class AsyncRedisSentinelManager(socketio.AsyncRedisManager):
+    def __init__(self, sentinel_hosts, sentinel_port=26379, redis_port=6379, service_name="mymaster", db=0,
+                 username=None, password=None, channel='socketio', write_only=False, **kwargs):
+        """
+        Initialize the Redis Sentinel Manager.
+
+        :param sentinel_hosts: List of Sentinel hosts
+        :param sentinel_port: Sentinel Port
+        :param redis_port: Redis Port
+        :param service_name: Master service name in Sentinel
+        :param db: Redis database to use
+        :param username: Redis username (if any)
+        :param password: Redis password (if any)
+        :param channel: The Redis channel name
+        :param write_only: If set to True, only initialize the connection to send messages
+        :param kwargs: Additional connection arguments for Redis
+        """
+        self.sentinel_addresses = [(host, sentinel_port) for host in sentinel_hosts]
+        self.redis_port=redis_port
+        self.service_name = service_name
+        self.db = db
+        self.username = username
+        self.password = password
+        self.channel = channel
+        self.write_only = write_only
+        self.redis_kwargs = kwargs
+
+        # Skip parent's init but call grandparent's init
+        socketio.AsyncManager.__init__(self)
+        self._redis_connect()
+
+    def _redis_connect(self):
+        """Establish connections to Redis through Sentinel."""
+        sentinel = redis.sentinel.Sentinel(
+            self.sentinel_addresses,
+            port=self.redis_port,
+            db=self.db,
+            username=self.username,
+            password=self.password,
+            **self.redis_kwargs
+        )
+
+        # Get connections to the Redis master and slave
+        self.redis = sentinel.master_for(self.service_name)
+        if not self.write_only:
+            self.pubsub = sentinel.slave_for(self.service_name).pubsub()
+            self.pubsub.subscribe(self.channel)
+
 
 logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
 log = logging.getLogger(__name__)
@@ -29,7 +80,12 @@ log.setLevel(SRC_LOG_LEVELS["SOCKET"])
 
 
 if WEBSOCKET_MANAGER == "redis":
-    mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
+    if WEBSOCKET_SENTINEL_HOSTS:
+        redis_config = parse_redis_sentinel_url(WEBSOCKET_REDIS_URL)
+        mgr = AsyncRedisSentinelManager(WEBSOCKET_SENTINEL_HOSTS.split(','), sentinel_port=int(WEBSOCKET_SENTINEL_PORT), redis_port=redis_config["port"],
+                                        service=redis_config["service"], db=redis_config["db"], username=redis_config["username"], password=redis_config["password"])
+    else:
+        mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
     sio = socketio.AsyncServer(
         cors_allowed_origins=[],
         async_mode="asgi",

+ 14 - 1
backend/open_webui/socket/utils.py

@@ -1,7 +1,20 @@
 import json
 import redis
 import uuid
-
+from urllib.parse import urlparse
+
+def parse_redis_sentinel_url(redis_url):
+    parsed_url = urlparse(redis_url)
+    if parsed_url.scheme != "redis":
+        raise ValueError("Invalid Redis URL scheme. Must be 'redis'.")
+
+    return {
+        "username": parsed_url.username or None,
+        "password": parsed_url.password or None,
+        "service": parsed_url.hostname or 'mymaster',
+        "port": parsed_url.port or 6379,
+        "db": int(parsed_url.path.lstrip("/") or 0),
+    }
 
 class RedisLock:
     def __init__(self, redis_url, lock_name, timeout_secs):