utils.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. import json
  2. import uuid
  3. from open_webui.utils.redis import get_redis_connection
  4. from open_webui.env import REDIS_KEY_PREFIX
  5. from typing import Optional, List, Tuple
  6. import pycrdt as Y
  class RedisLock:
      """Distributed lock held in a single Redis key with an expiry.

      The key's value is a random per-instance token (``lock_id``), so an
      instance can tell its own lock apart from one taken by someone else.
      Renewal and release run as Lua scripts, so the ownership check and the
      write execute atomically on the server. A GET followed by a separate
      write could act on a lock that has meanwhile expired and been acquired
      by another instance.
      """

      # Extend the TTL (ARGV[2] seconds) only while KEYS[1] still holds our token.
      _RENEW_SCRIPT = (
          "if redis.call('get', KEYS[1]) == ARGV[1] then "
          "return redis.call('expire', KEYS[1], ARGV[2]) "
          "else return 0 end"
      )

      # Delete KEYS[1] only while it still holds our token.
      _RELEASE_SCRIPT = (
          "if redis.call('get', KEYS[1]) == ARGV[1] then "
          "return redis.call('del', KEYS[1]) "
          "else return 0 end"
      )

      def __init__(
          self,
          redis_url,
          lock_name,
          timeout_secs,
          redis_sentinels=None,
          redis_cluster=False,
      ):
          """Prepare a lock handle for ``lock_name``.

          Args:
              redis_url: Connection URL forwarded to ``get_redis_connection``.
              lock_name: Redis key used as the lock.
              timeout_secs: Lock expiry in seconds (SET ``ex`` / EXPIRE value).
              redis_sentinels: Optional sentinel list forwarded to
                  ``get_redis_connection``; ``None`` is passed on as ``[]``.
              redis_cluster: Forwarded to ``get_redis_connection``.
          """
          self.lock_name = lock_name
          self.lock_id = str(uuid.uuid4())
          self.timeout_secs = timeout_secs
          self.lock_obtained = False
          self.redis = get_redis_connection(
              redis_url,
              redis_sentinels if redis_sentinels is not None else [],
              redis_cluster=redis_cluster,
              decode_responses=True,
          )

      def acquire_lock(self):
          """Try to take the lock.

          Returns:
              The result of ``SET NX``: truthy when the lock was obtained,
              ``None`` when the key already exists. The result is also
              stored in ``self.lock_obtained``.
          """
          # nx=True only sets the key if it does not exist yet.
          self.lock_obtained = self.redis.set(
              self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs
          )
          return self.lock_obtained

      # Original (misspelled) name, kept so existing callers keep working.
      aquire_lock = acquire_lock

      def renew_lock(self):
          """Reset the lock's expiry, but only if this instance still owns it.

          Returns:
              True if the TTL was extended. False if the key is missing or
              holds another instance's token; another holder's lock is never
              overwritten.
          """
          result = self.redis.eval(
              self._RENEW_SCRIPT,
              1,
              self.lock_name,
              self.lock_id,
              self.timeout_secs,
          )
          return bool(result)

      def release_lock(self):
          """Delete the lock key if, and only if, it still holds our token.

          Returns:
              True if this call deleted the key, False otherwise.
          """
          result = self.redis.eval(
              self._RELEASE_SCRIPT, 1, self.lock_name, self.lock_id
          )
          self.lock_obtained = False
          return bool(result)
  class RedisDict:
      """Dict-like view over one Redis hash whose values are JSON-encoded.

      Every operation is sent straight to Redis; nothing is cached locally.
      """

      def __init__(self, name, redis_url, redis_sentinels=None, redis_cluster=False):
          """Bind to the Redis hash called ``name``.

          Args:
              name: Key of the Redis hash backing this mapping.
              redis_url: Connection URL forwarded to ``get_redis_connection``.
              redis_sentinels: Optional sentinel list forwarded to
                  ``get_redis_connection``; ``None`` is passed on as ``[]``.
              redis_cluster: Forwarded to ``get_redis_connection``.
          """
          self.name = name
          self.redis = get_redis_connection(
              redis_url,
              redis_sentinels if redis_sentinels is not None else [],
              redis_cluster=redis_cluster,
              decode_responses=True,
          )

      def __setitem__(self, key, value):
          self.redis.hset(self.name, key, json.dumps(value))

      def __getitem__(self, key):
          value = self.redis.hget(self.name, key)
          if value is None:
              raise KeyError(key)
          return json.loads(value)

      def __delitem__(self, key):
          if self.redis.hdel(self.name, key) == 0:
              raise KeyError(key)

      def __contains__(self, key):
          return self.redis.hexists(self.name, key)

      def __iter__(self):
          # Without __iter__, Python falls back to the legacy __getitem__(0),
          # __getitem__(1), ... protocol. That raises KeyError here instead of
          # yielding the hash's field names.
          return iter(self.keys())

      def __len__(self):
          return self.redis.hlen(self.name)

      def keys(self):
          """Return the hash's field names as a list."""
          return self.redis.hkeys(self.name)

      def values(self):
          """Return all decoded values as a list."""
          return [json.loads(v) for v in self.redis.hvals(self.name)]

      def items(self):
          """Return ``(field, decoded value)`` pairs as a list."""
          return [(k, json.loads(v)) for k, v in self.redis.hgetall(self.name).items()]

      def get(self, key, default=None):
          """Return the decoded value for ``key``, or ``default`` if absent."""
          try:
              return self[key]
          except KeyError:
              return default

      def clear(self):
          """Delete the whole hash."""
          self.redis.delete(self.name)

      def update(self, other=None, **kwargs):
          """Set several fields, mirroring ``dict.update``.

          ``other`` may be a mapping (anything with ``.items()``) or an
          iterable of ``(key, value)`` pairs. Keyword arguments are applied
          after it, so they win on duplicate keys.
          """
          pending = {}
          if other is not None:
              pairs = other.items() if hasattr(other, "items") else other
              for k, v in pairs:
                  pending[k] = json.dumps(v)
          for k, v in kwargs.items():
              pending[k] = json.dumps(v)
          if pending:
              # A single HSET with a mapping is one round trip and one atomic
              # write. Redis rejects HSET with no fields, hence the guard.
              self.redis.hset(self.name, mapping=pending)

      def setdefault(self, key, default=None):
          """Return the value for ``key``, storing ``default`` first if absent."""
          if key not in self:
              # HSETNX rather than HSET: if another client created the field
              # after the check above, keep its value instead of clobbering it.
              self.redis.hsetnx(self.name, key, json.dumps(default))
          return self[key]
  class YdocManager:
      """Store Yjs document updates and the users attached to each document.

      When ``redis`` is given, state lives in Redis under
      ``{redis_key_prefix}:{document_id}:updates`` (a list) and
      ``{redis_key_prefix}:{document_id}:users`` (a set). Otherwise it is
      kept in process-local dicts. Every Redis call is awaited, so ``redis``
      is presumably an asyncio Redis client (TODO confirm against caller).

      Document ids have ``:`` replaced with ``_`` before use. This lets
      ``remove_user_from_all_documents`` recover an id as the second-to-last
      ``:``-separated component of a key.
      """

      def __init__(
          self,
          redis=None,
          redis_key_prefix: str = f"{REDIS_KEY_PREFIX}:ydoc:documents",
      ):
          self._updates = {}  # in-memory mode: document_id -> list of update bytes
          self._users = {}  # in-memory mode: document_id -> set of user ids
          self._redis = redis
          self._redis_key_prefix = redis_key_prefix

      @staticmethod
      def _normalize_id(document_id: str) -> str:
          """Replace ':' so the id can be embedded in a ':'-separated key."""
          return document_id.replace(":", "_")

      def _updates_key(self, document_id: str) -> str:
          """Redis list key holding a (normalized) document's updates."""
          return f"{self._redis_key_prefix}:{document_id}:updates"

      def _users_key(self, document_id: str) -> str:
          """Redis set key holding a (normalized) document's user ids."""
          return f"{self._redis_key_prefix}:{document_id}:users"

      async def append_to_updates(self, document_id: str, update: bytes):
          """Append one binary update to the document's update history."""
          document_id = self._normalize_id(document_id)
          if self._redis:
              # Stored as a JSON list of ints; get_updates() reverses this.
              await self._redis.rpush(
                  self._updates_key(document_id), json.dumps(list(update))
              )
          else:
              self._updates.setdefault(document_id, []).append(update)

      async def get_updates(self, document_id: str) -> List[bytes]:
          """Return the document's updates in append order ([] if none)."""
          document_id = self._normalize_id(document_id)
          if self._redis:
              updates = await self._redis.lrange(self._updates_key(document_id), 0, -1)
              return [bytes(json.loads(update)) for update in updates]
          # Copy so callers cannot mutate the stored history.
          return list(self._updates.get(document_id, []))

      async def document_exists(self, document_id: str) -> bool:
          """Return True if any update has been stored for the document."""
          document_id = self._normalize_id(document_id)
          if self._redis:
              return await self._redis.exists(self._updates_key(document_id)) > 0
          return document_id in self._updates

      async def get_users(self, document_id: str) -> List[str]:
          """Return the user ids attached to the document, as a list."""
          document_id = self._normalize_id(document_id)
          if self._redis:
              return list(await self._redis.smembers(self._users_key(document_id)))
          # The in-memory store is a set; convert so both modes return a list.
          return list(self._users.get(document_id, ()))

      async def add_user(self, document_id: str, user_id: str):
          """Attach ``user_id`` to the document (no-op if already attached)."""
          document_id = self._normalize_id(document_id)
          if self._redis:
              await self._redis.sadd(self._users_key(document_id), user_id)
          else:
              self._users.setdefault(document_id, set()).add(user_id)

      async def remove_user(self, document_id: str, user_id: str):
          """Detach ``user_id`` from the document (no-op if not attached)."""
          document_id = self._normalize_id(document_id)
          if self._redis:
              await self._redis.srem(self._users_key(document_id), user_id)
          else:
              self._users.get(document_id, set()).discard(user_id)

      async def remove_user_from_all_documents(self, user_id: str):
          """Detach ``user_id`` everywhere.

          Any document whose user set becomes empty is cleared entirely.
          """
          if self._redis:
              # SCAN rather than KEYS: KEYS walks the whole keyspace in a single
              # blocking call. The keys are collected first, so the writes below
              # do not interleave with the cursor.
              pattern = f"{self._redis_key_prefix}:*:users"
              keys = [key async for key in self._redis.scan_iter(match=pattern)]
              for key in keys:
                  await self._redis.srem(key, user_id)
                  if await self._redis.scard(key) == 0:
                      await self.clear_document(key.split(":")[-2])
          else:
              for document_id in list(self._users):
                  users = self._users[document_id]
                  if user_id in users:
                      users.remove(user_id)
                      if not users:
                          del self._users[document_id]
                          await self.clear_document(document_id)

      async def clear_document(self, document_id: str):
          """Drop all stored updates and users for the document."""
          document_id = self._normalize_id(document_id)
          if self._redis:
              # Two single-key DELs: the keys may hash to different cluster slots.
              await self._redis.delete(self._updates_key(document_id))
              await self._redis.delete(self._users_key(document_id))
          else:
              self._updates.pop(document_id, None)
              self._users.pop(document_id, None)