redis.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. import inspect
  2. from urllib.parse import urlparse
  3. import logging
  4. import redis
  5. from open_webui.env import REDIS_SENTINEL_MAX_RETRY_COUNT
  # Logger shared by everything in this module.
  log = logging.getLogger(__name__)

  # Memo of connections built by get_redis_connection, keyed by
  # (redis_url, sentinels tuple, async_mode, decode_responses).
  _CONNECTION_CACHE = dict()
  class SentinelRedisProxy:
      """Forward Redis calls to the current master of a Sentinel-managed service.

      Each forwarded method call looks the master up again through
      ``sentinel.master_for`` and is retried when it fails with
      ``redis.exceptions.ConnectionError`` or ``redis.exceptions.ReadOnlyError``.
      It makes up to ``REDIS_SENTINEL_MAX_RETRY_COUNT`` attempts in total, and
      always at least one.

      Args:
          sentinel: Object exposing ``master_for(service, **kw)``.
          service: Name of the master service to resolve.
          async_mode: If true, wrapped methods are coroutine functions that
              await the underlying call's result when it is a coroutine.
          **kw: Extra keyword arguments forwarded to ``master_for``.
      """

      # Methods that hand back helper objects rather than running a command.
      # They are returned unwrapped so callers can use them directly.
      _FACTORY_METHODS = frozenset(
          {"pipeline", "pubsub", "monitor", "client", "transaction"}
      )

      # The proxy's own state. One of these only reaches __getattr__ when
      # __init__ has not run (e.g. an instance created via __new__). Forwarding
      # it would call _master(), which reads self._sentinel and re-enters
      # __getattr__ forever.
      _OWN_ATTRS = frozenset({"_sentinel", "_service", "_kw", "_async_mode"})

      def __init__(self, sentinel, service, *, async_mode: bool = True, **kw):
          self._sentinel = sentinel
          self._service = service
          self._kw = kw
          self._async_mode = async_mode

      def _master(self):
          """Return a client for the service's current master."""
          return self._sentinel.master_for(self._service, **self._kw)

      @staticmethod
      def _should_retry(exc, attempt, attempts):
          """Log a fail-over error and report whether another attempt is allowed.

          Args:
              exc: Exception raised by the failed attempt.
              attempt: Zero-based index of the failed attempt.
              attempts: Total number of attempts permitted.

          Returns:
              True if the caller should try again, False if it should re-raise.
          """
          if attempt < attempts - 1:
              log.debug(
                  "Redis sentinel fail-over (%s). Retry %s/%s",
                  type(exc).__name__,
                  attempt + 1,
                  attempts,
              )
              return True
          log.error(
              "Redis operation failed after %s retries: %s",
              attempts,
              exc,
          )
          return False

      def __getattr__(self, item):
          """Resolve *item* on the current master, wrapping callables with retries.

          Non-callable attributes and the methods in ``_FACTORY_METHODS`` are
          returned as-is.

          Raises:
              AttributeError: If *item* is one of the proxy's own attributes
                  that has not been set yet, or if the master lacks *item*.
          """
          if item in self._OWN_ATTRS:
              raise AttributeError(
                  f"{type(self).__name__!r} object has no attribute {item!r}"
              )

          orig_attr = getattr(self._master(), item)
          if not callable(orig_attr) or item in self._FACTORY_METHODS:
              return orig_attr

          # A non-positive setting would skip the retry loop entirely and make
          # every call silently return None. Always make at least one attempt.
          attempts = max(1, REDIS_SENTINEL_MAX_RETRY_COUNT)

          if self._async_mode:

              async def _wrapped(*args, **kwargs):
                  for attempt in range(attempts):
                      try:
                          method = getattr(self._master(), item)
                          result = method(*args, **kwargs)
                          if inspect.iscoroutine(result):
                              return await result
                          return result
                      except (
                          redis.exceptions.ConnectionError,
                          redis.exceptions.ReadOnlyError,
                      ) as exc:
                          if not self._should_retry(exc, attempt, attempts):
                              # A bare raise keeps the original traceback
                              # without making the error its own __cause__.
                              raise

              return _wrapped

          def _wrapped(*args, **kwargs):
              for attempt in range(attempts):
                  try:
                      method = getattr(self._master(), item)
                      return method(*args, **kwargs)
                  except (
                      redis.exceptions.ConnectionError,
                      redis.exceptions.ReadOnlyError,
                  ) as exc:
                      if not self._should_retry(exc, attempt, attempts):
                          raise

          return _wrapped
  def parse_redis_service_url(redis_url):
      """Break a ``redis://`` URL into the settings used to reach a Sentinel service.

      The URL's host part is interpreted as the Sentinel service name.

      Args:
          redis_url: URL of the form ``redis://[user:pass@]service[:port][/db]``.

      Returns:
          A dict with keys ``username`` and ``password`` (None when absent or
          empty), ``service`` (default ``"mymaster"``), ``port`` (default 6379)
          and ``db`` (default 0).

      Raises:
          ValueError: If the scheme is not ``redis``, or the port or db part
              is not numeric.

      NOTE(review): username/password are returned as they appear in the URL.
      urlparse does not percent-decode them.
      """
      url = urlparse(redis_url)
      if url.scheme != "redis":
          raise ValueError("Invalid Redis URL scheme. Must be 'redis'.")

      db_text = url.path.lstrip("/")
      return dict(
          username=url.username or None,
          password=url.password or None,
          service=url.hostname or "mymaster",
          port=url.port or 6379,
          db=int(db_text) if db_text else 0,
      )
  def _build_sentinel_proxy(
      redis_module, redis_url, redis_sentinels, async_mode, decode_responses
  ):
      """Create a SentinelRedisProxy for the master service named in *redis_url*.

      *redis_module* is either the ``redis`` package or ``redis.asyncio``. Its
      ``sentinel.Sentinel`` class is used to talk to *redis_sentinels*.
      """
      from urllib.parse import unquote

      redis_config = parse_redis_service_url(redis_url)
      # urlparse() leaves credentials percent-encoded. Redis.from_url (used for
      # the direct-connection branch) decodes them, so do the same here.
      # Otherwise a password written as "p%40ss" would be sent verbatim
      # instead of "p@ss".
      username = redis_config["username"]
      password = redis_config["password"]
      sentinel = redis_module.sentinel.Sentinel(
          redis_sentinels,
          port=redis_config["port"],
          db=redis_config["db"],
          username=unquote(username) if username else None,
          password=unquote(password) if password else None,
          decode_responses=decode_responses,
      )
      return SentinelRedisProxy(
          sentinel,
          redis_config["service"],
          async_mode=async_mode,
      )


  def get_redis_connection(
      redis_url, redis_sentinels, async_mode=False, decode_responses=True
  ):
      """Return a memoized Redis client or Sentinel proxy for the given settings.

      Args:
          redis_url: ``redis://`` URL. With Sentinels, its host part names the
              master service; otherwise it is passed to ``Redis.from_url``.
          redis_sentinels: Sentinel addresses (presumably ``(host, port)``
              pairs as built by get_sentinels_from_env), or a falsy value to
              connect directly via *redis_url*.
          async_mode: Use ``redis.asyncio`` instead of the synchronous client.
          decode_responses: Forwarded to the client.

      Returns:
          A SentinelRedisProxy, a Redis client, or None when neither
          *redis_sentinels* nor *redis_url* is set. The result is stored in
          ``_CONNECTION_CACHE`` and reused for identical arguments.
      """
      cache_key = (
          redis_url,
          tuple(redis_sentinels) if redis_sentinels else (),
          async_mode,
          decode_responses,
      )
      if cache_key in _CONNECTION_CACHE:
          return _CONNECTION_CACHE[cache_key]

      # Bind the sync or asyncio flavour under one local name so both modes
      # share the construction logic below.
      if async_mode:
          import redis.asyncio as redis_module
      else:
          import redis as redis_module

      connection = None
      if redis_sentinels:
          connection = _build_sentinel_proxy(
              redis_module, redis_url, redis_sentinels, async_mode, decode_responses
          )
      elif redis_url:
          connection = redis_module.Redis.from_url(
              redis_url, decode_responses=decode_responses
          )

      _CONNECTION_CACHE[cache_key] = connection
      return connection
  def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):
      """Build the list of ``(host, port)`` Sentinel addresses.

      Args:
          sentinel_hosts_env: Comma-separated host names, e.g. ``"a,b"``.
              Whitespace around each name is ignored and empty entries (such
              as those left by a trailing comma) are skipped.
          sentinel_port_env: Port shared by every host, converted with int().

      Returns:
          A list of ``(host, port)`` tuples, or ``[]`` when
          *sentinel_hosts_env* is empty or None.

      Raises:
          ValueError / TypeError: If hosts are given but int() cannot convert
              *sentinel_port_env*.
      """
      if not sentinel_hosts_env:
          return []
      sentinel_port = int(sentinel_port_env)
      hosts = (host.strip() for host in sentinel_hosts_env.split(","))
      return [(host, sentinel_port) for host in hosts if host]
  def get_sentinel_url_from_env(redis_url, sentinel_hosts_env, sentinel_port_env):
      """Build a ``redis+sentinel://`` URL from a Redis URL and Sentinel settings.

      Credentials, db index and service name come from *redis_url* via
      parse_redis_service_url. The credentials are embedded as they appeared
      there, not re-encoded. The host list is *sentinel_hosts_env* split on
      commas, with each host paired with *sentinel_port_env*.

      Returns:
          A string of the form
          ``redis+sentinel://[user:pass@]h1:port,h2:port/<db>/<service>``.

      Raises:
          ValueError: If *redis_url* is rejected by parse_redis_service_url.
      """
      redis_config = parse_redis_service_url(redis_url)
      username = redis_config["username"] or ""
      password = redis_config["password"] or ""
      auth_part = f"{username}:{password}@" if (username or password) else ""

      # Tolerate whitespace around hosts and empty entries such as a trailing
      # comma, which would otherwise end up inside the URL's host list.
      hosts = [host.strip() for host in sentinel_hosts_env.split(",")]
      hosts_part = ",".join(
          f"{host}:{sentinel_port_env}" for host in hosts if host
      )
      return (
          f"redis+sentinel://{auth_part}{hosts_part}"
          f"/{redis_config['db']}/{redis_config['service']}"
      )