import json
import uuid
from open_webui.utils.redis import get_redis_connection
from open_webui.env import REDIS_KEY_PREFIX
from typing import Optional, List, Tuple

import pycrdt as Y


class RedisLock:
    """Redis-backed mutual-exclusion lock with an expiry.

    The lock is a single Redis string key (``lock_name``) whose value is a
    random per-instance token (``lock_id``). Acquisition uses ``SET NX EX``.
    Renewal and release first check that the key still holds this instance's
    token, and do so atomically inside a Lua script. A lock that expired and
    was re-acquired by someone else is therefore never extended or deleted
    by the previous holder.
    """

    # Extend the TTL only if the key still holds our token (atomic in Redis).
    _RENEW_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
    return redis.call("expire", KEYS[1], ARGV[2])
end
return 0
"""

    # Delete the key only if it still holds our token (atomic in Redis).
    _RELEASE_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
    return redis.call("del", KEYS[1])
end
return 0
"""

    def __init__(
        self,
        redis_url,
        lock_name,
        timeout_secs,
        redis_sentinels=None,
        redis_cluster=False,
    ):
        """Create a lock handle; no Redis command is issued until acquired.

        Args:
            redis_url: Connection URL passed to ``get_redis_connection``.
            lock_name: Redis key used as the lock.
            timeout_secs: Lock expiry (seconds) applied on acquire/renew.
            redis_sentinels: Optional sentinel list; ``None`` means ``[]``.
            redis_cluster: Whether to connect in cluster mode.
        """
        self.lock_name = lock_name
        self.lock_id = str(uuid.uuid4())
        self.timeout_secs = timeout_secs
        self.lock_obtained = False
        self.redis = get_redis_connection(
            redis_url,
            [] if redis_sentinels is None else redis_sentinels,
            redis_cluster=redis_cluster,
            decode_responses=True,
        )

    def aquire_lock(self):
        """Try to take the lock; return a truthy value on success.

        The name keeps its historical spelling for existing callers;
        ``acquire_lock`` is an alias.
        """
        # nx=True will only set this key if it _hasn't_ already been set
        self.lock_obtained = self.redis.set(
            self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs
        )
        return self.lock_obtained

    acquire_lock = aquire_lock

    def renew_lock(self):
        """Extend the lock's TTL if this instance still owns it.

        Returns:
            True if the TTL was extended, False if the lock is missing or is
            held by a different token.
        """
        result = self.redis.eval(
            self._RENEW_SCRIPT, 1, self.lock_name, self.lock_id, self.timeout_secs
        )
        return bool(result)

    def release_lock(self):
        """Release the lock if (and only if) this instance still owns it."""
        self.redis.eval(self._RELEASE_SCRIPT, 1, self.lock_name, self.lock_id)
        self.lock_obtained = False


class RedisDict:
    """Dict-like view over a single Redis hash with JSON-encoded values.

    Every operation goes straight to Redis (hash key ``name``). Values are
    serialized with ``json.dumps`` on write and ``json.loads`` on read.
    """

    def __init__(self, name, redis_url, redis_sentinels=None, redis_cluster=False):
        self.name = name
        self.redis = get_redis_connection(
            redis_url,
            [] if redis_sentinels is None else redis_sentinels,
            redis_cluster=redis_cluster,
            decode_responses=True,
        )

    def __setitem__(self, key, value):
        self.redis.hset(self.name, key, json.dumps(value))

    def __getitem__(self, key):
        value = self.redis.hget(self.name, key)
        if value is None:
            raise KeyError(key)
        return json.loads(value)

    def __delitem__(self, key):
        if self.redis.hdel(self.name, key) == 0:
            raise KeyError(key)

    def __contains__(self, key):
        return self.redis.hexists(self.name, key)

    def __len__(self):
        return self.redis.hlen(self.name)

    def __iter__(self):
        # Without this, iter() falls back to the legacy sequence protocol
        # (__getitem__(0), __getitem__(1), ...) and raises KeyError.
        return iter(self.keys())

    def keys(self):
        return self.redis.hkeys(self.name)

    def values(self):
        return [json.loads(v) for v in self.redis.hvals(self.name)]

    def items(self):
        return [(k, json.loads(v)) for k, v in self.redis.hgetall(self.name).items()]

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def clear(self):
        self.redis.delete(self.name)

    def update(self, other=None, **kwargs):
        """Merge ``other`` (mapping or iterable of pairs) and ``kwargs`` in.

        All fields are written with a single HSET; ``kwargs`` win over
        ``other`` on duplicate keys, as with ``dict.update``.
        """
        serialized = {}
        if other is not None:
            pairs = other.items() if hasattr(other, "items") else other
            for k, v in pairs:
                serialized[k] = json.dumps(v)
        for k, v in kwargs.items():
            serialized[k] = json.dumps(v)
        if serialized:
            # redis-py rejects HSET with an empty mapping.
            self.redis.hset(self.name, mapping=serialized)

    def setdefault(self, key, default=None):
        """Store ``default`` under ``key`` unless present; return the value.

        HSETNX makes the check-and-set atomic, so concurrent writers can't
        overwrite each other's initial value.
        """
        self.redis.hsetnx(self.name, key, json.dumps(default))
        return self[key]


class YdocManager:
    """Track Yjs document updates and connected users per document.

    When a Redis client is supplied it is used for storage. Its methods are
    awaited here, so an asyncio-compatible client is assumed; TODO confirm
    against callers. Otherwise in-process dicts are used. Document ids have
    ':' replaced by '_' so they cannot break the ':'-delimited key layout
    ``{prefix}:{document_id}:{updates|users}``.
    """

    def __init__(
        self,
        redis=None,
        redis_key_prefix: str = f"{REDIS_KEY_PREFIX}:ydoc:documents",
    ):
        self._updates = {}  # document_id -> list of update bytes (in-memory mode)
        self._users = {}  # document_id -> set of user ids (in-memory mode)
        self._redis = redis
        self._redis_key_prefix = redis_key_prefix

    @staticmethod
    def _normalize_id(document_id: str) -> str:
        """Replace ':' so the id is safe inside ':'-delimited Redis keys."""
        return document_id.replace(":", "_")

    def _redis_key(self, document_id: str, suffix: str) -> str:
        """Build the Redis key for a normalized document id and suffix."""
        return f"{self._redis_key_prefix}:{document_id}:{suffix}"

    async def append_to_updates(self, document_id: str, update: bytes):
        """Append one binary update to the document's update log."""
        document_id = self._normalize_id(document_id)
        if self._redis:
            # Stored as a JSON list of ints; get_updates() decodes this format.
            await self._redis.rpush(
                self._redis_key(document_id, "updates"), json.dumps(list(update))
            )
        else:
            self._updates.setdefault(document_id, []).append(update)

    async def get_updates(self, document_id: str) -> List[bytes]:
        """Return the document's updates in append order (empty if none)."""
        document_id = self._normalize_id(document_id)
        if self._redis:
            updates = await self._redis.lrange(
                self._redis_key(document_id, "updates"), 0, -1
            )
            return [bytes(json.loads(update)) for update in updates]
        # Copy so callers cannot mutate the internal log.
        return list(self._updates.get(document_id, []))

    async def document_exists(self, document_id: str) -> bool:
        """Return True if any update has been stored for the document."""
        document_id = self._normalize_id(document_id)
        if self._redis:
            return await self._redis.exists(self._redis_key(document_id, "updates")) > 0
        return document_id in self._updates

    async def get_users(self, document_id: str) -> List[str]:
        """Return the ids of users currently attached to the document."""
        document_id = self._normalize_id(document_id)
        if self._redis:
            users = await self._redis.smembers(self._redis_key(document_id, "users"))
            return list(users)
        # Always a fresh list, matching the annotation and the Redis branch.
        return list(self._users.get(document_id, ()))

    async def add_user(self, document_id: str, user_id: str):
        """Attach ``user_id`` to the document's user set."""
        document_id = self._normalize_id(document_id)
        if self._redis:
            await self._redis.sadd(self._redis_key(document_id, "users"), user_id)
        else:
            self._users.setdefault(document_id, set()).add(user_id)

    async def remove_user(self, document_id: str, user_id: str):
        """Detach ``user_id`` from the document; no-op if not attached."""
        document_id = self._normalize_id(document_id)
        if self._redis:
            await self._redis.srem(self._redis_key(document_id, "users"), user_id)
        else:
            users = self._users.get(document_id)
            if users is not None:
                users.discard(user_id)

    async def remove_user_from_all_documents(self, user_id: str):
        """Detach ``user_id`` everywhere; clear documents left with no users."""
        if self._redis:
            # SCAN (incremental) instead of KEYS, which blocks the server while
            # walking the entire keyspace. Collect first, then mutate.
            pattern = f"{self._redis_key_prefix}:*:users"
            users_keys = [key async for key in self._redis.scan_iter(match=pattern)]
            for key in users_keys:
                if isinstance(key, bytes):
                    key = key.decode()
                await self._redis.srem(key, user_id)
                # Key layout is {prefix}:{document_id}:users and normalized ids
                # contain no ':', so the id is the second-to-last segment.
                document_id = key.split(":")[-2]
                if await self._redis.scard(key) == 0:
                    await self.clear_document(document_id)
        else:
            for document_id in list(self._users.keys()):
                if user_id in self._users[document_id]:
                    self._users[document_id].remove(user_id)
                    if not self._users[document_id]:
                        del self._users[document_id]
                        await self.clear_document(document_id)

    async def clear_document(self, document_id: str):
        """Drop all stored updates and users for the document."""
        document_id = self._normalize_id(document_id)
        if self._redis:
            # Separate DELs: the two keys may live in different cluster slots.
            await self._redis.delete(self._redis_key(document_id, "updates"))
            await self._redis.delete(self._redis_key(document_id, "users"))
        else:
            self._updates.pop(document_id, None)
            self._users.pop(document_id, None)