redis.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. import inspect
  2. from urllib.parse import urlparse
  3. import logging
  4. import redis
  5. from open_webui.env import REDIS_SENTINEL_MAX_RETRY_COUNT
  # Logger for this module.
  log = logging.getLogger(__name__)

  # Clients created by get_redis_connection(), keyed by the connection
  # settings it was called with, so repeated calls reuse the same client.
  _CONNECTION_CACHE = dict()
  class SentinelRedisProxy:
      """Forward attribute access to the current master of a Sentinel service.

      Non-callable attributes and the methods listed in ``_FACTORY_METHODS``
      are returned straight from the master client. Every other callable is
      wrapped. Each call re-resolves the master and, on
      ``redis.exceptions.ConnectionError`` or ``redis.exceptions.ReadOnlyError``,
      retries against a freshly resolved master. This repeats up to
      ``REDIS_SENTINEL_MAX_RETRY_COUNT`` attempts in total (at least one); the
      last error is re-raised.

      In ``async_mode`` the wrapper is a coroutine function that awaits the
      underlying call whenever it returns a coroutine.
      """

      # Methods handed out unwrapped: the proxy never retries them.
      _FACTORY_METHODS = frozenset(
          {"pipeline", "pubsub", "monitor", "client", "transaction"}
      )

      # Attributes assigned in __init__. If a lookup for one of them reaches
      # __getattr__, __init__ has not run (e.g. an instance made by copy or
      # pickle). Resolving the master would then recurse forever.
      _OWN_ATTRS = frozenset({"_sentinel", "_service", "_kw", "_async_mode"})

      def __init__(self, sentinel, service, *, async_mode: bool = True, **kw):
          """Store the Sentinel handle and the options used to find the master.

          :param sentinel: Sentinel client providing ``master_for``.
          :param service: Name of the monitored master service.
          :param async_mode: Wrap callables as coroutine functions when True.
          :param kw: Extra keyword arguments forwarded to ``master_for``.
          """
          self._sentinel = sentinel
          self._service = service
          self._kw = kw
          self._async_mode = async_mode

      def _master(self):
          """Return a client for the current master of ``self._service``."""
          return self._sentinel.master_for(self._service, **self._kw)

      @staticmethod
      def _attempts():
          """Return the number of attempts per call, never less than one."""
          # A non-positive setting used to skip the call entirely, so the
          # wrapper silently returned None.
          return max(1, REDIS_SENTINEL_MAX_RETRY_COUNT)

      @staticmethod
      def _should_retry(exc, attempt, attempts):
          """Log a fail-over error; return True if another attempt remains."""
          if attempt < attempts - 1:
              log.debug(
                  "Redis sentinel fail-over (%s). Retry %s/%s",
                  type(exc).__name__,
                  attempt + 1,
                  attempts,
              )
              return True
          log.error(
              "Redis operation failed after %s retries: %s",
              attempts,
              exc,
          )
          return False

      def __getattr__(self, item):
          """Resolve *item* on the current master, wrapping callables with retries."""
          if item in self._OWN_ATTRS:
              raise AttributeError(
                  f"{type(self).__name__!r} object has no attribute {item!r}"
              )

          orig_attr = getattr(self._master(), item)
          if not callable(orig_attr) or item in self._FACTORY_METHODS:
              return orig_attr

          retryable = (
              redis.exceptions.ConnectionError,
              redis.exceptions.ReadOnlyError,
          )

          if self._async_mode:

              async def _wrapped(*args, **kwargs):
                  attempts = self._attempts()
                  for attempt in range(attempts):
                      try:
                          # Re-resolve on every attempt so a fail-over is picked up.
                          result = getattr(self._master(), item)(*args, **kwargs)
                          if inspect.iscoroutine(result):
                              return await result
                          return result
                      except retryable as e:
                          if not self._should_retry(e, attempt, attempts):
                              # Bare raise keeps the original traceback intact.
                              raise

          else:

              def _wrapped(*args, **kwargs):
                  attempts = self._attempts()
                  for attempt in range(attempts):
                      try:
                          # Re-resolve on every attempt so a fail-over is picked up.
                          return getattr(self._master(), item)(*args, **kwargs)
                      except retryable as e:
                          if not self._should_retry(e, attempt, attempts):
                              # Bare raise keeps the original traceback intact.
                              raise

          return _wrapped
  def parse_redis_service_url(redis_url):
      """Split a ``redis://`` / ``rediss://`` URL into Sentinel settings.

      The URL's host part is interpreted as the Sentinel service (master) name.

      :param redis_url: URL such as ``redis://user:pass@mymaster:6379/0``.
      :returns: dict with these keys:

          - ``username`` / ``password``: exactly as written in the URL (still
            percent-encoded), or None when absent or empty.
          - ``service``: default ``"mymaster"``.
          - ``port``: default 6379.
          - ``db``: default 0.
      :raises ValueError: if the scheme is not ``redis``/``rediss``, or if the
          port or database number is not a valid integer.
      """
      parsed_url = urlparse(redis_url)
      if parsed_url.scheme not in ("redis", "rediss"):
          raise ValueError("Invalid Redis URL scheme. Must be 'redis' or 'rediss'.")

      # ``parsed_url.hostname`` is lower-cased by urllib, but Sentinel master
      # names are case-sensitive, so read the service name from the raw netloc.
      host_info = parsed_url.netloc.rpartition("@")[2]
      if host_info.startswith("["):
          service = host_info[1:].partition("]")[0]
      else:
          service = host_info.partition(":")[0]

      return {
          "username": parsed_url.username or None,
          "password": parsed_url.password or None,
          "service": service or "mymaster",
          "port": parsed_url.port or 6379,
          "db": int(parsed_url.path.lstrip("/") or 0),
      }
  def get_redis_connection(
      redis_url,
      redis_sentinels,
      redis_cluster=False,
      async_mode=False,
      decode_responses=True,
  ):
      """Return a Redis client for the given settings, reusing cached clients.

      Selection order:

      1. Sentinel, when *redis_sentinels* is non-empty.
      2. Cluster, when *redis_cluster* is true.
      3. A plain client built from *redis_url*.

      If none applies, ``None`` is returned (and cached). Every result,
      including cluster clients, is cached in ``_CONNECTION_CACHE``. Later
      calls with the same arguments get the same object.

      :param redis_url: ``redis://``/``rediss://`` URL. In Sentinel mode its
          host part is the service name (see ``parse_redis_service_url``).
      :param redis_sentinels: Sentinel addresses (e.g. ``(host, port)`` pairs),
          or an empty value to disable Sentinel.
      :param redis_cluster: Build a ``RedisCluster`` client from *redis_url*.
      :param async_mode: Use ``redis.asyncio`` instead of the sync package.
      :param decode_responses: Passed through to the client.
      :raises ValueError: in cluster mode when *redis_url* is empty, or when
          *redis_url* cannot be parsed in Sentinel mode.
      """
      from urllib.parse import unquote

      cache_key = (
          redis_url,
          tuple(redis_sentinels) if redis_sentinels else (),
          # Without this flag, a cluster request could be served the plain
          # client cached for the same URL (and vice versa).
          bool(redis_cluster),
          async_mode,
          decode_responses,
      )
      if cache_key in _CONNECTION_CACHE:
          return _CONNECTION_CACHE[cache_key]

      # redis.asyncio exposes the same Redis / sentinel / cluster layout as the
      # sync package, so one code path builds the client for either mode.
      if async_mode:
          import redis.asyncio as redis_module
      else:
          import redis as redis_module

      connection = None
      if redis_sentinels:
          redis_config = parse_redis_service_url(redis_url)
          # parse_redis_service_url returns credentials exactly as written in
          # the URL. Percent-decode them, as redis-py's from_url() does, so a
          # URL-encoded password also authenticates in Sentinel mode.
          username = redis_config["username"]
          password = redis_config["password"]
          sentinel = redis_module.sentinel.Sentinel(
              redis_sentinels,
              port=redis_config["port"],
              db=redis_config["db"],
              username=unquote(username) if username else username,
              password=unquote(password) if password else password,
              decode_responses=decode_responses,
          )
          connection = SentinelRedisProxy(
              sentinel,
              redis_config["service"],
              async_mode=async_mode,
          )
      elif redis_cluster:
          if not redis_url:
              raise ValueError("Redis URL must be provided for cluster mode.")
          # Cached like the other modes instead of being rebuilt on every call.
          connection = redis_module.cluster.RedisCluster.from_url(
              redis_url, decode_responses=decode_responses
          )
      elif redis_url:
          connection = redis_module.Redis.from_url(
              redis_url, decode_responses=decode_responses
          )

      _CONNECTION_CACHE[cache_key] = connection
      return connection
  def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):
      """Build ``(host, port)`` pairs for the Sentinel client.

      :param sentinel_hosts_env: Comma-separated host list, e.g. ``"s1, s2"``.
          Whitespace around each host is ignored and empty entries (such as
          from a trailing comma) are skipped.
      :param sentinel_port_env: Port shared by every host; converted with
          ``int()``.
      :returns: List of ``(host, port)`` tuples; empty when no hosts are given.
      """
      if not sentinel_hosts_env:
          return []
      sentinel_port = int(sentinel_port_env)
      hosts = [host.strip() for host in sentinel_hosts_env.split(",")]
      return [(host, sentinel_port) for host in hosts if host]
  def get_sentinel_url_from_env(redis_url, sentinel_hosts_env, sentinel_port_env):
      """Build a ``redis+sentinel://`` URL from a Redis URL and Sentinel hosts.

      Credentials, database and service name come from *redis_url* via
      ``parse_redis_service_url``. Credentials are copied verbatim from its
      result, and the auth part is omitted when both are empty. Each
      comma-separated host in *sentinel_hosts_env* is paired with
      *sentinel_port_env*. Whitespace around a host is ignored and empty
      entries are skipped, so they cannot produce malformed netloc entries.

      :returns: e.g. ``redis+sentinel://user:pass@s1:26379,s2:26379/0/mymaster``.
      """
      redis_config = parse_redis_service_url(redis_url)
      username = redis_config["username"] or ""
      password = redis_config["password"] or ""
      auth_part = f"{username}:{password}@" if username or password else ""

      hosts = (host.strip() for host in sentinel_hosts_env.split(","))
      hosts_part = ",".join(f"{host}:{sentinel_port_env}" for host in hosts if host)

      return (
          f"redis+sentinel://{auth_part}{hosts_part}"
          f"/{redis_config['db']}/{redis_config['service']}"
      )