utils.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. import json
  2. import uuid
  3. from open_webui.utils.redis import get_redis_connection
  4. from typing import Optional, List, Tuple
  5. import pycrdt as Y
  6. class RedisLock:
  7. def __init__(self, redis_url, lock_name, timeout_secs, redis_sentinels=[]):
  8. self.lock_name = lock_name
  9. self.lock_id = str(uuid.uuid4())
  10. self.timeout_secs = timeout_secs
  11. self.lock_obtained = False
  12. self.redis = get_redis_connection(
  13. redis_url, redis_sentinels, decode_responses=True
  14. )
  15. def aquire_lock(self):
  16. # nx=True will only set this key if it _hasn't_ already been set
  17. self.lock_obtained = self.redis.set(
  18. self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs
  19. )
  20. return self.lock_obtained
  21. def renew_lock(self):
  22. # xx=True will only set this key if it _has_ already been set
  23. return self.redis.set(
  24. self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs
  25. )
  26. def release_lock(self):
  27. lock_value = self.redis.get(self.lock_name)
  28. if lock_value and lock_value == self.lock_id:
  29. self.redis.delete(self.lock_name)
  30. class RedisDict:
  31. def __init__(self, name, redis_url, redis_sentinels=[]):
  32. self.name = name
  33. self.redis = get_redis_connection(
  34. redis_url, redis_sentinels, decode_responses=True
  35. )
  36. def __setitem__(self, key, value):
  37. serialized_value = json.dumps(value)
  38. self.redis.hset(self.name, key, serialized_value)
  39. def __getitem__(self, key):
  40. value = self.redis.hget(self.name, key)
  41. if value is None:
  42. raise KeyError(key)
  43. return json.loads(value)
  44. def __delitem__(self, key):
  45. result = self.redis.hdel(self.name, key)
  46. if result == 0:
  47. raise KeyError(key)
  48. def __contains__(self, key):
  49. return self.redis.hexists(self.name, key)
  50. def __len__(self):
  51. return self.redis.hlen(self.name)
  52. def keys(self):
  53. return self.redis.hkeys(self.name)
  54. def values(self):
  55. return [json.loads(v) for v in self.redis.hvals(self.name)]
  56. def items(self):
  57. return [(k, json.loads(v)) for k, v in self.redis.hgetall(self.name).items()]
  58. def get(self, key, default=None):
  59. try:
  60. return self[key]
  61. except KeyError:
  62. return default
  63. def clear(self):
  64. self.redis.delete(self.name)
  65. def update(self, other=None, **kwargs):
  66. if other is not None:
  67. for k, v in other.items() if hasattr(other, "items") else other:
  68. self[k] = v
  69. for k, v in kwargs.items():
  70. self[k] = v
  71. def setdefault(self, key, default=None):
  72. if key not in self:
  73. self[key] = default
  74. return self[key]
  75. class YdocManager:
  76. def __init__(
  77. self,
  78. redis=None,
  79. redis_key_prefix: str = "open-webui:ydoc:documents",
  80. ):
  81. self._updates = {}
  82. self._users = {}
  83. self._redis = redis
  84. self._redis_key_prefix = redis_key_prefix
  85. async def append_to_updates(self, document_id: str, update: bytes):
  86. document_id = document_id.replace(":", "_")
  87. if self._redis:
  88. redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
  89. await self._redis.rpush(redis_key, json.dumps(list(update)))
  90. else:
  91. if document_id not in self._updates:
  92. self._updates[document_id] = []
  93. self._updates[document_id].append(update)
  94. async def get_updates(self, document_id: str) -> List[bytes]:
  95. document_id = document_id.replace(":", "_")
  96. if self._redis:
  97. redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
  98. updates = await self._redis.lrange(redis_key, 0, -1)
  99. return [bytes(json.loads(update)) for update in updates]
  100. else:
  101. return self._updates.get(document_id, [])
  102. async def document_exists(self, document_id: str) -> bool:
  103. document_id = document_id.replace(":", "_")
  104. if self._redis:
  105. redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
  106. return await self._redis.exists(redis_key) > 0
  107. else:
  108. return document_id in self._updates
  109. async def get_users(self, document_id: str) -> List[str]:
  110. document_id = document_id.replace(":", "_")
  111. if self._redis:
  112. redis_key = f"{self._redis_key_prefix}:{document_id}:users"
  113. users = await self._redis.smembers(redis_key)
  114. return list(users)
  115. else:
  116. return self._users.get(document_id, [])
  117. async def add_user(self, document_id: str, user_id: str):
  118. document_id = document_id.replace(":", "_")
  119. if self._redis:
  120. redis_key = f"{self._redis_key_prefix}:{document_id}:users"
  121. await self._redis.sadd(redis_key, user_id)
  122. else:
  123. if document_id not in self._users:
  124. self._users[document_id] = set()
  125. self._users[document_id].add(user_id)
  126. async def remove_user(self, document_id: str, user_id: str):
  127. document_id = document_id.replace(":", "_")
  128. if self._redis:
  129. redis_key = f"{self._redis_key_prefix}:{document_id}:users"
  130. await self._redis.srem(redis_key, user_id)
  131. else:
  132. if document_id in self._users and user_id in self._users[document_id]:
  133. self._users[document_id].remove(user_id)
  134. async def remove_user_from_all_documents(self, user_id: str):
  135. if self._redis:
  136. keys = await self._redis.keys(f"{self._redis_key_prefix}:*")
  137. for key in keys:
  138. if key.endswith(":users"):
  139. await self._redis.srem(key, user_id)
  140. document_id = key.split(":")[-2]
  141. if len(await self.get_users(document_id)) == 0:
  142. await self.clear_document(document_id)
  143. else:
  144. for document_id in list(self._users.keys()):
  145. if user_id in self._users[document_id]:
  146. self._users[document_id].remove(user_id)
  147. if not self._users[document_id]:
  148. del self._users[document_id]
  149. await self.clear_document(document_id)
  150. async def clear_document(self, document_id: str):
  151. document_id = document_id.replace(":", "_")
  152. if self._redis:
  153. redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
  154. await self._redis.delete(redis_key)
  155. redis_users_key = f"{self._redis_key_prefix}:{document_id}:users"
  156. await self._redis.delete(redis_users_key)
  157. else:
  158. if document_id in self._updates:
  159. del self._updates[document_id]
  160. if document_id in self._users:
  161. del self._users[document_id]