utils.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. import json
  2. import uuid
  3. from open_webui.utils.redis import get_redis_connection
  4. from open_webui.env import REDIS_KEY_PREFIX
  5. from typing import Optional, List, Tuple
  6. import pycrdt as Y
  7. class RedisLock:
  8. def __init__(self, redis_url, lock_name, timeout_secs, redis_sentinels=[]):
  9. self.lock_name = lock_name
  10. self.lock_id = str(uuid.uuid4())
  11. self.timeout_secs = timeout_secs
  12. self.lock_obtained = False
  13. self.redis = get_redis_connection(
  14. redis_url, redis_sentinels, decode_responses=True
  15. )
  16. def aquire_lock(self):
  17. # nx=True will only set this key if it _hasn't_ already been set
  18. self.lock_obtained = self.redis.set(
  19. self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs
  20. )
  21. return self.lock_obtained
  22. def renew_lock(self):
  23. # xx=True will only set this key if it _has_ already been set
  24. return self.redis.set(
  25. self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs
  26. )
  27. def release_lock(self):
  28. lock_value = self.redis.get(self.lock_name)
  29. if lock_value and lock_value == self.lock_id:
  30. self.redis.delete(self.lock_name)
  31. class RedisDict:
  32. def __init__(self, name, redis_url, redis_sentinels=[]):
  33. self.name = name
  34. self.redis = get_redis_connection(
  35. redis_url, redis_sentinels, decode_responses=True
  36. )
  37. def __setitem__(self, key, value):
  38. serialized_value = json.dumps(value)
  39. self.redis.hset(self.name, key, serialized_value)
  40. def __getitem__(self, key):
  41. value = self.redis.hget(self.name, key)
  42. if value is None:
  43. raise KeyError(key)
  44. return json.loads(value)
  45. def __delitem__(self, key):
  46. result = self.redis.hdel(self.name, key)
  47. if result == 0:
  48. raise KeyError(key)
  49. def __contains__(self, key):
  50. return self.redis.hexists(self.name, key)
  51. def __len__(self):
  52. return self.redis.hlen(self.name)
  53. def keys(self):
  54. return self.redis.hkeys(self.name)
  55. def values(self):
  56. return [json.loads(v) for v in self.redis.hvals(self.name)]
  57. def items(self):
  58. return [(k, json.loads(v)) for k, v in self.redis.hgetall(self.name).items()]
  59. def get(self, key, default=None):
  60. try:
  61. return self[key]
  62. except KeyError:
  63. return default
  64. def clear(self):
  65. self.redis.delete(self.name)
  66. def update(self, other=None, **kwargs):
  67. if other is not None:
  68. for k, v in other.items() if hasattr(other, "items") else other:
  69. self[k] = v
  70. for k, v in kwargs.items():
  71. self[k] = v
  72. def setdefault(self, key, default=None):
  73. if key not in self:
  74. self[key] = default
  75. return self[key]
  76. class YdocManager:
  77. def __init__(
  78. self,
  79. redis=None,
  80. redis_key_prefix: str = f"{REDIS_KEY_PREFIX}:ydoc:documents",
  81. ):
  82. self._updates = {}
  83. self._users = {}
  84. self._redis = redis
  85. self._redis_key_prefix = redis_key_prefix
  86. async def append_to_updates(self, document_id: str, update: bytes):
  87. document_id = document_id.replace(":", "_")
  88. if self._redis:
  89. redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
  90. await self._redis.rpush(redis_key, json.dumps(list(update)))
  91. else:
  92. if document_id not in self._updates:
  93. self._updates[document_id] = []
  94. self._updates[document_id].append(update)
  95. async def get_updates(self, document_id: str) -> List[bytes]:
  96. document_id = document_id.replace(":", "_")
  97. if self._redis:
  98. redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
  99. updates = await self._redis.lrange(redis_key, 0, -1)
  100. return [bytes(json.loads(update)) for update in updates]
  101. else:
  102. return self._updates.get(document_id, [])
  103. async def document_exists(self, document_id: str) -> bool:
  104. document_id = document_id.replace(":", "_")
  105. if self._redis:
  106. redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
  107. return await self._redis.exists(redis_key) > 0
  108. else:
  109. return document_id in self._updates
  110. async def get_users(self, document_id: str) -> List[str]:
  111. document_id = document_id.replace(":", "_")
  112. if self._redis:
  113. redis_key = f"{self._redis_key_prefix}:{document_id}:users"
  114. users = await self._redis.smembers(redis_key)
  115. return list(users)
  116. else:
  117. return self._users.get(document_id, [])
  118. async def add_user(self, document_id: str, user_id: str):
  119. document_id = document_id.replace(":", "_")
  120. if self._redis:
  121. redis_key = f"{self._redis_key_prefix}:{document_id}:users"
  122. await self._redis.sadd(redis_key, user_id)
  123. else:
  124. if document_id not in self._users:
  125. self._users[document_id] = set()
  126. self._users[document_id].add(user_id)
  127. async def remove_user(self, document_id: str, user_id: str):
  128. document_id = document_id.replace(":", "_")
  129. if self._redis:
  130. redis_key = f"{self._redis_key_prefix}:{document_id}:users"
  131. await self._redis.srem(redis_key, user_id)
  132. else:
  133. if document_id in self._users and user_id in self._users[document_id]:
  134. self._users[document_id].remove(user_id)
  135. async def remove_user_from_all_documents(self, user_id: str):
  136. if self._redis:
  137. keys = await self._redis.keys(f"{self._redis_key_prefix}:*")
  138. for key in keys:
  139. if key.endswith(":users"):
  140. await self._redis.srem(key, user_id)
  141. document_id = key.split(":")[-2]
  142. if len(await self.get_users(document_id)) == 0:
  143. await self.clear_document(document_id)
  144. else:
  145. for document_id in list(self._users.keys()):
  146. if user_id in self._users[document_id]:
  147. self._users[document_id].remove(user_id)
  148. if not self._users[document_id]:
  149. del self._users[document_id]
  150. await self.clear_document(document_id)
  151. async def clear_document(self, document_id: str):
  152. document_id = document_id.replace(":", "_")
  153. if self._redis:
  154. redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
  155. await self._redis.delete(redis_key)
  156. redis_users_key = f"{self._redis_key_prefix}:{document_id}:users"
  157. await self._redis.delete(redis_users_key)
  158. else:
  159. if document_id in self._updates:
  160. del self._updates[document_id]
  161. if document_id in self._users:
  162. del self._users[document_id]