Browse Source

enh/refac: redis cluster support

Timothy Jaeryang Baek 2 months ago
parent
commit
35400daf19

+ 9 - 3
backend/open_webui/config.py

@@ -7,7 +7,7 @@ import redis
 
 from datetime import datetime
 from pathlib import Path
-from typing import Generic, Optional, TypeVar
+from typing import Generic, Union, Optional, TypeVar
 from urllib.parse import urlparse
 
 import requests
@@ -213,13 +213,14 @@ class PersistentConfig(Generic[T]):
 
 class AppConfig:
     _state: dict[str, PersistentConfig]
-    _redis: Optional[redis.Redis] = None
+    _redis: Union[redis.Redis, redis.cluster.RedisCluster] = None
     _redis_key_prefix: str
 
     def __init__(
         self,
         redis_url: Optional[str] = None,
         redis_sentinels: Optional[list] = [],
+        redis_cluster: Optional[bool] = False,
         redis_key_prefix: str = "open-webui",
     ):
         super().__setattr__("_state", {})
@@ -227,7 +228,12 @@ class AppConfig:
         if redis_url:
             super().__setattr__(
                 "_redis",
-                get_redis_connection(redis_url, redis_sentinels, decode_responses=True),
+                get_redis_connection(
+                    redis_url,
+                    redis_sentinels,
+                    redis_cluster,
+                    decode_responses=True,
+                ),
             )
 
     def __setattr__(self, key, value):

+ 7 - 1
backend/open_webui/env.py

@@ -346,7 +346,10 @@ ENABLE_REALTIME_CHAT_SAVE = (
 ####################################
 
 REDIS_URL = os.environ.get("REDIS_URL", "")
+REDIS_CLUSTER = os.environ.get("REDIS_CLUSTER", "False").lower() == "true"
+
 REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui")
+
 REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "")
 REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379")
 
@@ -489,6 +492,9 @@ ENABLE_WEBSOCKET_SUPPORT = (
 WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
 
 WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
+WEBSOCKET_REDIS_CLUSTER = (
+    os.environ.get("WEBSOCKET_REDIS_CLUSTER", str(REDIS_CLUSTER)).lower() == "true"
+)
 
 websocket_redis_lock_timeout = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", "60")
 
@@ -498,9 +504,9 @@ except ValueError:
     WEBSOCKET_REDIS_LOCK_TIMEOUT = 60
 
 WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "")
-
 WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379")
 
+
 AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
 
 if AIOHTTP_CLIENT_TIMEOUT == "":

+ 3 - 0
backend/open_webui/main.py

@@ -399,6 +399,7 @@ from open_webui.env import (
     AUDIT_LOG_LEVEL,
     CHANGELOG,
     REDIS_URL,
+    REDIS_CLUSTER,
     REDIS_KEY_PREFIX,
     REDIS_SENTINEL_HOSTS,
     REDIS_SENTINEL_PORT,
@@ -525,6 +526,7 @@ async def lifespan(app: FastAPI):
         redis_sentinels=get_sentinels_from_env(
             REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
         ),
+        redis_cluster=REDIS_CLUSTER,
         async_mode=True,
     )
 
@@ -580,6 +582,7 @@ app.state.instance_id = None
 app.state.config = AppConfig(
     redis_url=REDIS_URL,
     redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
+    redis_cluster=REDIS_CLUSTER,
     redis_key_prefix=REDIS_KEY_PREFIX,
 )
 app.state.redis = None

+ 6 - 0
backend/open_webui/socket/main.py

@@ -22,6 +22,7 @@ from open_webui.env import (
     ENABLE_WEBSOCKET_SUPPORT,
     WEBSOCKET_MANAGER,
     WEBSOCKET_REDIS_URL,
+    WEBSOCKET_REDIS_CLUSTER,
     WEBSOCKET_REDIS_LOCK_TIMEOUT,
     WEBSOCKET_SENTINEL_PORT,
     WEBSOCKET_SENTINEL_HOSTS,
@@ -86,6 +87,7 @@ if WEBSOCKET_MANAGER == "redis":
         redis_sentinels=get_sentinels_from_env(
             WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
         ),
+        redis_cluster=WEBSOCKET_REDIS_CLUSTER,
         async_mode=True,
     )
 
@@ -96,16 +98,19 @@ if WEBSOCKET_MANAGER == "redis":
         f"{REDIS_KEY_PREFIX}:session_pool",
         redis_url=WEBSOCKET_REDIS_URL,
         redis_sentinels=redis_sentinels,
+        redis_cluster=WEBSOCKET_REDIS_CLUSTER,
     )
     USER_POOL = RedisDict(
         f"{REDIS_KEY_PREFIX}:user_pool",
         redis_url=WEBSOCKET_REDIS_URL,
         redis_sentinels=redis_sentinels,
+        redis_cluster=WEBSOCKET_REDIS_CLUSTER,
     )
     USAGE_POOL = RedisDict(
         f"{REDIS_KEY_PREFIX}:usage_pool",
         redis_url=WEBSOCKET_REDIS_URL,
         redis_sentinels=redis_sentinels,
+        redis_cluster=WEBSOCKET_REDIS_CLUSTER,
     )
 
     clean_up_lock = RedisLock(
@@ -113,6 +118,7 @@ if WEBSOCKET_MANAGER == "redis":
         lock_name="usage_cleanup_lock",
         timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
         redis_sentinels=redis_sentinels,
+        redis_cluster=WEBSOCKET_REDIS_CLUSTER,
     )
     aquire_func = clean_up_lock.aquire_lock
     renew_func = clean_up_lock.renew_lock

+ 18 - 4
backend/open_webui/socket/utils.py

@@ -7,13 +7,24 @@ import pycrdt as Y
 
 
 class RedisLock:
-    def __init__(self, redis_url, lock_name, timeout_secs, redis_sentinels=[]):
+    def __init__(
+        self,
+        redis_url,
+        lock_name,
+        timeout_secs,
+        redis_sentinels=[],
+        redis_cluster=False,
+    ):
+
         self.lock_name = lock_name
         self.lock_id = str(uuid.uuid4())
         self.timeout_secs = timeout_secs
         self.lock_obtained = False
         self.redis = get_redis_connection(
-            redis_url, redis_sentinels, decode_responses=True
+            redis_url,
+            redis_sentinels,
+            redis_cluster=redis_cluster,
+            decode_responses=True,
         )
 
     def aquire_lock(self):
@@ -36,10 +47,13 @@ class RedisLock:
 
 
 class RedisDict:
-    def __init__(self, name, redis_url, redis_sentinels=[]):
+    def __init__(self, name, redis_url, redis_sentinels=[], redis_cluster=False):
         self.name = name
         self.redis = get_redis_connection(
-            redis_url, redis_sentinels, decode_responses=True
+            redis_url,
+            redis_sentinels,
+            redis_cluster=redis_cluster,
+            decode_responses=True,
         )
 
     def __setitem__(self, key, value):

+ 28 - 5
backend/open_webui/utils/redis.py

@@ -96,8 +96,8 @@ class SentinelRedisProxy:
 
 def parse_redis_service_url(redis_url):
     parsed_url = urlparse(redis_url)
-    if parsed_url.scheme != "redis":
-        raise ValueError("Invalid Redis URL scheme. Must be 'redis'.")
+    if parsed_url.scheme != "redis" and parsed_url.scheme != "rediss":
+        raise ValueError("Invalid Redis URL scheme. Must be 'redis' or 'rediss'.")
 
     return {
         "username": parsed_url.username or None,
@@ -109,10 +109,19 @@ def parse_redis_service_url(redis_url):
 
 
 def get_redis_connection(
-    redis_url, redis_sentinels, async_mode=False, decode_responses=True
+    redis_url,
+    redis_sentinels,
+    redis_cluster=False,
+    async_mode=False,
+    decode_responses=True,
 ):
 
-    cache_key = (redis_url, tuple(redis_sentinels) if redis_sentinels else (), async_mode, decode_responses)
+    cache_key = (
+        redis_url,
+        tuple(redis_sentinels) if redis_sentinels else (),
+        async_mode,
+        decode_responses,
+    )
 
     if cache_key in _CONNECTION_CACHE:
         return _CONNECTION_CACHE[cache_key]
@@ -138,6 +147,12 @@ def get_redis_connection(
                 redis_config["service"],
                 async_mode=async_mode,
             )
+        elif redis_cluster:
+            if not redis_url:
+                raise ValueError("Redis URL must be provided for cluster mode.")
+            return redis.cluster.RedisCluster.from_url(
+                redis_url, decode_responses=decode_responses
+            )
         elif redis_url:
             connection = redis.from_url(redis_url, decode_responses=decode_responses)
     else:
@@ -158,8 +173,16 @@ def get_redis_connection(
                 redis_config["service"],
                 async_mode=async_mode,
             )
+        elif redis_cluster:
+            if not redis_url:
+                raise ValueError("Redis URL must be provided for cluster mode.")
+            return redis.cluster.RedisCluster.from_url(
+                redis_url, decode_responses=decode_responses
+            )
         elif redis_url:
-            connection = redis.Redis.from_url(redis_url, decode_responses=decode_responses)
+            connection = redis.Redis.from_url(
+                redis_url, decode_responses=decode_responses
+            )
 
     _CONNECTION_CACHE[cache_key] = connection
     return connection