Browse Source

second part of adding Redis Sentinel support

Jan Kessler 4 months ago
parent
commit
4370332e32
2 changed files with 33 additions and 7 deletions
  1. 7 3
      backend/open_webui/socket/main.py
  2. 26 4
      backend/open_webui/socket/utils.py

+ 7 - 3
backend/open_webui/socket/main.py

@@ -111,14 +111,18 @@ TIMEOUT_DURATION = 3
 
 if WEBSOCKET_MANAGER == "redis":
     log.debug("Using Redis to manage websockets.")
-    SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
-    USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
-    USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
+    sentinel_hosts=WEBSOCKET_SENTINEL_HOSTS.split(',')
+    sentinel_port=int(WEBSOCKET_SENTINEL_PORT)
+    sentinels=[(host, sentinel_port) for host in sentinel_hosts]
+    SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL, sentinels)
+    USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL, sentinels)
+    USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL, sentinels)
 
     clean_up_lock = RedisLock(
         redis_url=WEBSOCKET_REDIS_URL,
         lock_name="usage_cleanup_lock",
         timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
+        sentinels,
     )
     aquire_func = clean_up_lock.aquire_lock
     renew_func = clean_up_lock.renew_lock

+ 26 - 4
backend/open_webui/socket/utils.py

@@ -16,13 +16,35 @@ def parse_redis_sentinel_url(redis_url):
         "db": int(parsed_url.path.lstrip("/") or 0),
     }
 
+def get_redis_connection(redis_url, sentinels, decode_responses=True):
+    """
+    Creates a Redis connection from either a standard Redis URL or uses special
+    parsing to setup a Sentinel connection, if given an array of host/port tuples.
+    """
+    if sentinels:
+        redis_config = parse_redis_sentinel_url(redis_url)
+        sentinel = redis.sentinel.Sentinel(
+            self.sentinels,
+            port=redis_config['port'],
+            db=redis_config['db'],
+            username=redis_config['username'],
+            password=redis_config['password'],
+            decode_responses=decode_responses
+        }
+
+        # Get a master connection from Sentinel
+        return sentinel.master_for(redis_config['service'])
+    else:
+        # Standard Redis connection
+        return redis.Redis.from_url(redis_url, decode_responses=decode_responses)
+
 class RedisLock:
-    def __init__(self, redis_url, lock_name, timeout_secs):
+    def __init__(self, redis_url, lock_name, timeout_secs, sentinels=[]):
         self.lock_name = lock_name
         self.lock_id = str(uuid.uuid4())
         self.timeout_secs = timeout_secs
         self.lock_obtained = False
-        self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
+        self.redis = get_redis_connection(redis_url, sentinels, decode_responses=True)
 
     def aquire_lock(self):
         # nx=True will only set this key if it _hasn't_ already been set
@@ -44,9 +66,9 @@ class RedisLock:
 
 
 class RedisDict:
-    def __init__(self, name, redis_url):
+    def __init__(self, name, redis_url, sentinels=[]):
         self.name = name
-        self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
+        self.redis = get_redis_connection(redis_url, sentinels, decode_responses=True)
 
     def __setitem__(self, key, value):
         serialized_value = json.dumps(value)