redis.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import inspect
  2. from urllib.parse import urlparse
  3. import redis
  4. from open_webui.env import REDIS_SENTINEL_MAX_RETRY_COUNT
  5. class SentinelRedisProxy:
  6. def __init__(self, sentinel, service, *, async_mode: bool = True, **kw):
  7. self._sentinel = sentinel
  8. self._service = service
  9. self._kw = kw
  10. self._async_mode = async_mode
  11. def _master(self):
  12. return self._sentinel.master_for(self._service, **self._kw)
  13. def __getattr__(self, item):
  14. master = self._master()
  15. orig_attr = getattr(master, item)
  16. if not callable(orig_attr):
  17. return orig_attr
  18. FACTORY_METHODS = {"pipeline", "pubsub", "monitor", "client", "transaction"}
  19. if item in FACTORY_METHODS:
  20. return orig_attr
  21. if self._async_mode:
  22. async def _wrapped(*args, **kwargs):
  23. for i in range(REDIS_SENTINEL_MAX_RETRY_COUNT):
  24. try:
  25. method = getattr(self._master(), item)
  26. result = method(*args, **kwargs)
  27. if inspect.iscoroutine(result):
  28. return await result
  29. return result
  30. except (
  31. redis.exceptions.ConnectionError,
  32. redis.exceptions.ReadOnlyError,
  33. ) as e:
  34. if i < REDIS_SENTINEL_MAX_RETRY_COUNT - 1:
  35. continue
  36. raise e from e
  37. return _wrapped
  38. else:
  39. def _wrapped(*args, **kwargs):
  40. for i in range(REDIS_SENTINEL_MAX_RETRY_COUNT):
  41. try:
  42. method = getattr(self._master(), item)
  43. return method(*args, **kwargs)
  44. except (
  45. redis.exceptions.ConnectionError,
  46. redis.exceptions.ReadOnlyError,
  47. ) as e:
  48. if i < REDIS_SENTINEL_MAX_RETRY_COUNT - 1:
  49. continue
  50. raise e from e
  51. return _wrapped
  52. def parse_redis_service_url(redis_url):
  53. parsed_url = urlparse(redis_url)
  54. if parsed_url.scheme != "redis":
  55. raise ValueError("Invalid Redis URL scheme. Must be 'redis'.")
  56. return {
  57. "username": parsed_url.username or None,
  58. "password": parsed_url.password or None,
  59. "service": parsed_url.hostname or "mymaster",
  60. "port": parsed_url.port or 6379,
  61. "db": int(parsed_url.path.lstrip("/") or 0),
  62. }
  63. def get_redis_connection(
  64. redis_url, redis_sentinels, async_mode=False, decode_responses=True
  65. ):
  66. if async_mode:
  67. import redis.asyncio as redis
  68. # If using sentinel in async mode
  69. if redis_sentinels:
  70. redis_config = parse_redis_service_url(redis_url)
  71. sentinel = redis.sentinel.Sentinel(
  72. redis_sentinels,
  73. port=redis_config["port"],
  74. db=redis_config["db"],
  75. username=redis_config["username"],
  76. password=redis_config["password"],
  77. decode_responses=decode_responses,
  78. )
  79. return SentinelRedisProxy(
  80. sentinel,
  81. redis_config["service"],
  82. async_mode=async_mode,
  83. )
  84. elif redis_url:
  85. return redis.from_url(redis_url, decode_responses=decode_responses)
  86. else:
  87. return None
  88. else:
  89. import redis
  90. if redis_sentinels:
  91. redis_config = parse_redis_service_url(redis_url)
  92. sentinel = redis.sentinel.Sentinel(
  93. redis_sentinels,
  94. port=redis_config["port"],
  95. db=redis_config["db"],
  96. username=redis_config["username"],
  97. password=redis_config["password"],
  98. decode_responses=decode_responses,
  99. )
  100. return SentinelRedisProxy(
  101. sentinel,
  102. redis_config["service"],
  103. async_mode=async_mode,
  104. )
  105. elif redis_url:
  106. return redis.Redis.from_url(redis_url, decode_responses=decode_responses)
  107. else:
  108. return None
  109. def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):
  110. if sentinel_hosts_env:
  111. sentinel_hosts = sentinel_hosts_env.split(",")
  112. sentinel_port = int(sentinel_port_env)
  113. return [(host, sentinel_port) for host in sentinel_hosts]
  114. return []
  115. def get_sentinel_url_from_env(redis_url, sentinel_hosts_env, sentinel_port_env):
  116. redis_config = parse_redis_service_url(redis_url)
  117. username = redis_config["username"] or ""
  118. password = redis_config["password"] or ""
  119. auth_part = ""
  120. if username or password:
  121. auth_part = f"{username}:{password}@"
  122. hosts_part = ",".join(
  123. f"{host}:{sentinel_port_env}" for host in sentinel_hosts_env.split(",")
  124. )
  125. return f"redis+sentinel://{auth_part}{hosts_part}/{redis_config['db']}/{redis_config['service']}"