123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209 |
- import inspect
- from urllib.parse import urlparse
- import logging
- import redis
- from open_webui.env import REDIS_SENTINEL_MAX_RETRY_COUNT
# Module logger; records are emitted under this module's dotted name.
log = logging.getLogger(name=__name__)

# Memo of clients created by get_redis_connection(), keyed by a tuple built
# from that function's arguments, so repeated calls reuse one client object.
_CONNECTION_CACHE: dict = {}
class SentinelRedisProxy:
    """Proxy that forwards attribute access to the current Sentinel master.

    Every attribute lookup resolves the master via
    ``sentinel.master_for(service, **kw)``. Callable attributes (except the
    factory methods in ``_FACTORY_METHODS``) are wrapped so that a
    ``redis.exceptions.ConnectionError`` or ``redis.exceptions.ReadOnlyError``
    is retried against a freshly resolved master, up to
    ``REDIS_SENTINEL_MAX_RETRY_COUNT`` attempts (at least one attempt is
    always made).

    With ``async_mode=True`` the wrappers are coroutine functions that await
    the command's result when it is awaitable; otherwise they are plain
    functions returning the command's result directly.
    """

    # Methods that hand back helper objects (pipelines, pub/sub, ...) instead
    # of executing a command; they are returned unwrapped.
    _FACTORY_METHODS = frozenset(
        {"pipeline", "pubsub", "monitor", "client", "transaction"}
    )

    # Attributes assigned in __init__. If one is missing (e.g. the instance was
    # created without __init__, as copy/pickle do), __getattr__ must raise
    # AttributeError instead of recursing through _master() forever.
    _OWN_ATTRS = frozenset({"_sentinel", "_service", "_kw", "_async_mode"})

    def __init__(self, sentinel, service, *, async_mode: bool = True, **kw):
        self._sentinel = sentinel
        self._service = service
        self._kw = kw
        self._async_mode = async_mode

    def _master(self):
        """Return a client for the current master of the configured service."""
        return self._sentinel.master_for(self._service, **self._kw)

    @staticmethod
    def _attempts() -> int:
        """Return how many tries a call gets; never fewer than one."""
        # A non-positive setting previously skipped the command entirely and
        # silently returned None.
        return max(1, int(REDIS_SENTINEL_MAX_RETRY_COUNT))

    @staticmethod
    def _should_retry(attempt: int, attempts: int, exc: Exception) -> bool:
        """Log a fail-over error and report whether another attempt remains."""
        if attempt < attempts - 1:
            log.debug(
                "Redis sentinel fail-over (%s). Retry %s/%s",
                type(exc).__name__,
                attempt + 1,
                REDIS_SENTINEL_MAX_RETRY_COUNT,
            )
            return True
        log.error(
            "Redis operation failed after %s retries: %s",
            REDIS_SENTINEL_MAX_RETRY_COUNT,
            exc,
        )
        return False

    def __getattr__(self, item):
        # Only reached when normal lookup fails; guard the proxy's own state.
        if item in type(self)._OWN_ATTRS:
            raise AttributeError(item)

        orig_attr = getattr(self._master(), item)
        if not callable(orig_attr) or item in self._FACTORY_METHODS:
            return orig_attr

        retryable = (
            redis.exceptions.ConnectionError,
            redis.exceptions.ReadOnlyError,
        )

        if self._async_mode:

            async def _wrapped(*args, **kwargs):
                attempts = self._attempts()
                for attempt in range(attempts):
                    try:
                        # Re-resolve the master each try so a fail-over is
                        # picked up.
                        method = getattr(self._master(), item)
                        result = method(*args, **kwargs)
                        if inspect.isawaitable(result):
                            return await result
                        return result
                    except retryable as e:
                        if not self._should_retry(attempt, attempts, e):
                            # Bare raise keeps the original traceback intact.
                            raise

            return _wrapped

        def _wrapped(*args, **kwargs):
            attempts = self._attempts()
            for attempt in range(attempts):
                try:
                    method = getattr(self._master(), item)
                    return method(*args, **kwargs)
                except retryable as e:
                    if not self._should_retry(attempt, attempts, e):
                        raise

        return _wrapped
def parse_redis_service_url(redis_url):
    """Split a ``redis://`` / ``rediss://`` URL into Sentinel connection settings.

    The URL's host part is interpreted as the Sentinel service name.

    Returns:
        dict with keys ``username`` and ``password`` (``None`` when absent,
        otherwise exactly as ``urlparse`` reports them), ``service`` (host
        with its original case, defaulting to ``"mymaster"``), ``port``
        (defaulting to 6379) and ``db`` (from the URL path, defaulting to 0).

    Raises:
        ValueError: the scheme is not ``redis``/``rediss``, the port is
            invalid, or the path is not a database number.
    """
    parsed_url = urlparse(redis_url)
    if parsed_url.scheme not in ("redis", "rediss"):
        raise ValueError("Invalid Redis URL scheme. Must be 'redis' or 'rediss'.")

    # ``parsed_url.hostname`` lower-cases the host, but Sentinel service names
    # are case-sensitive, so read the host straight from the netloc instead.
    host_info = parsed_url.netloc.rpartition("@")[2]
    if host_info.startswith("["):
        service = host_info[1:].partition("]")[0]
    else:
        service = host_info.partition(":")[0]

    # Tolerate a trailing slash such as ``/3/``.
    db_part = parsed_url.path.strip("/")
    try:
        db = int(db_part or 0)
    except ValueError:
        raise ValueError(
            f"Invalid Redis database in URL path: {parsed_url.path!r}"
        ) from None

    return {
        "username": parsed_url.username or None,
        "password": parsed_url.password or None,
        "service": service or "mymaster",
        "port": parsed_url.port or 6379,
        "db": db,
    }
def _url_credential(value):
    """Percent-decode a URL credential like ``Redis.from_url`` does; keep falsy values as-is."""
    from urllib.parse import unquote

    return unquote(value) if value else value


def _build_sentinel_proxy(
    redis_lib, redis_url, redis_sentinels, async_mode, decode_responses
):
    """Create a Sentinel client from ``redis_lib`` and wrap it in SentinelRedisProxy."""
    redis_config = parse_redis_service_url(redis_url)
    sentinel = redis_lib.sentinel.Sentinel(
        redis_sentinels,
        port=redis_config["port"],
        db=redis_config["db"],
        username=_url_credential(redis_config["username"]),
        password=_url_credential(redis_config["password"]),
        decode_responses=decode_responses,
    )
    return SentinelRedisProxy(
        sentinel,
        redis_config["service"],
        async_mode=async_mode,
    )


def get_redis_connection(
    redis_url,
    redis_sentinels,
    redis_cluster=False,
    async_mode=False,
    decode_responses=True,
):
    """Return a Redis client for the given configuration, reusing cached ones.

    Precedence: sentinels, then cluster, then a plain URL. Clients are built
    from ``redis.asyncio`` when ``async_mode`` is true, else from ``redis``.
    Repeated calls with equal arguments return the same cached object.

    Args:
        redis_url: ``redis://``/``rediss://`` URL. In sentinel mode its host is
            the service name (see ``parse_redis_service_url``).
        redis_sentinels: iterable of ``(host, port)`` pairs; falsy disables
            sentinel mode.
        redis_cluster: build a ``RedisCluster`` client from ``redis_url``.
        async_mode: use the asyncio client classes.
        decode_responses: forwarded to the client constructor.

    Returns:
        A ``SentinelRedisProxy``, ``RedisCluster`` or ``Redis`` client, or
        ``None`` when neither sentinels nor a URL were provided.

    Raises:
        ValueError: cluster mode without a URL, or an invalid URL in
            sentinel mode.
    """
    cache_key = (
        redis_url,
        # Normalise pairs to tuples so list-of-lists input stays hashable.
        tuple(tuple(s) for s in redis_sentinels) if redis_sentinels else (),
        # Part of the key so cluster and non-cluster clients never collide.
        bool(redis_cluster),
        async_mode,
        decode_responses,
    )
    cached = _CONNECTION_CACHE.get(cache_key)
    if cached is not None:
        return cached

    if async_mode:
        import redis.asyncio as redis_lib
    else:
        import redis as redis_lib

    if redis_sentinels:
        connection = _build_sentinel_proxy(
            redis_lib, redis_url, redis_sentinels, async_mode, decode_responses
        )
    elif redis_cluster:
        if not redis_url:
            raise ValueError("Redis URL must be provided for cluster mode.")
        connection = redis_lib.cluster.RedisCluster.from_url(
            redis_url, decode_responses=decode_responses
        )
    elif redis_url:
        connection = redis_lib.Redis.from_url(
            redis_url, decode_responses=decode_responses
        )
    else:
        # Nothing configured: nothing worth caching.
        return None

    _CONNECTION_CACHE[cache_key] = connection
    return connection
def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):
    """Build the ``(host, port)`` list for a Sentinel client from env settings.

    Args:
        sentinel_hosts_env: comma-separated host names; falsy means "no
            sentinels". Surrounding whitespace and empty entries are ignored.
        sentinel_port_env: port shared by all hosts, converted with ``int``.

    Returns:
        List of ``(host, port)`` tuples; empty when no hosts are given.

    Raises:
        ValueError: ``sentinel_port_env`` is not an integer while hosts are set.
    """
    if not sentinel_hosts_env:
        return []
    sentinel_port = int(sentinel_port_env)
    hosts = (host.strip() for host in sentinel_hosts_env.split(","))
    # Skip blanks such as those produced by "a,b," or "a, ,b".
    return [(host, sentinel_port) for host in hosts if host]
def get_sentinel_url_from_env(redis_url, sentinel_hosts_env, sentinel_port_env):
    """Build a ``redis+sentinel://`` URL from a Redis URL and sentinel settings.

    Result shape:
    ``redis+sentinel://[user:password@]host1:port,host2:port/<db>/<service>``.
    Credentials, db and service come from ``parse_redis_service_url(redis_url)``
    and are inserted as returned by it.

    Args:
        redis_url: ``redis://``/``rediss://`` URL whose host is the service name.
        sentinel_hosts_env: comma-separated sentinel host names; surrounding
            whitespace and empty entries are ignored.
        sentinel_port_env: port appended to every host, inserted verbatim.

    Raises:
        ValueError: propagated from ``parse_redis_service_url`` for bad URLs.
    """
    redis_config = parse_redis_service_url(redis_url)
    username = redis_config["username"] or ""
    password = redis_config["password"] or ""
    auth_part = f"{username}:{password}@" if username or password else ""

    hosts = (host.strip() for host in sentinel_hosts_env.split(","))
    hosts_part = ",".join(f"{host}:{sentinel_port_env}" for host in hosts if host)

    return f"redis+sentinel://{auth_part}{hosts_part}/{redis_config['db']}/{redis_config['service']}"
|