|
@@ -4,16 +4,88 @@ from typing import Dict
|
|
|
from uuid import uuid4
|
|
|
import json
|
|
|
from redis import Redis
|
|
|
+from fastapi import Request
|
|
|
+from typing import Dict, List, Optional
|
|
|
|
|
|
# A dictionary to keep track of active tasks
|
|
|
tasks: Dict[str, asyncio.Task] = {}
|
|
|
chat_tasks = {}
|
|
|
|
|
|
|
|
|
+REDIS_TASKS_KEY = "open-webui:tasks"
|
|
|
+REDIS_CHAT_TASKS_KEY = "open-webui:tasks:chat"
|
|
|
+REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands"
|
|
|
+
|
|
|
+
|
|
|
+def is_redis(request: Request) -> bool:
|
|
|
+ # Called everywhere a request is available to check Redis
|
|
|
+ return hasattr(request.app.state, "redis") and isinstance(
|
|
|
+ request.app.state.redis, Redis
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+async def redis_task_command_listener(app):
|
|
|
+ redis: Redis = app.state.redis
|
|
|
+ pubsub = redis.pubsub()
|
|
|
+ await pubsub.subscribe(REDIS_PUBSUB_CHANNEL)
|
|
|
+ print("Subscribed to Redis task command channel")
|
|
|
+
|
|
|
+ async for message in pubsub.listen():
|
|
|
+ if message["type"] != "message":
|
|
|
+ continue
|
|
|
+ try:
|
|
|
+ command = json.loads(message["data"])
|
|
|
+ if command.get("action") == "stop":
|
|
|
+ task_id = command.get("task_id")
|
|
|
+ local_task = tasks.get(task_id)
|
|
|
+ if local_task:
|
|
|
+ local_task.cancel()
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error handling distributed task command: {e}")
|
|
|
+
|
|
|
+
|
|
|
+### ------------------------------
|
|
|
+### REDIS-ENABLED HANDLERS
|
|
|
+### ------------------------------
|
|
|
+
|
|
|
+
|
|
|
+def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]):
|
|
|
+ pipe = redis.pipeline()
|
|
|
+ pipe.hset(REDIS_TASKS_KEY, task_id, chat_id or "")
|
|
|
+ if chat_id:
|
|
|
+ pipe.sadd(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
|
|
|
+ pipe.execute()
|
|
|
+
|
|
|
+
|
|
|
+def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]):
|
|
|
+ pipe = redis.pipeline()
|
|
|
+ pipe.hdel(REDIS_TASKS_KEY, task_id)
|
|
|
+ if chat_id:
|
|
|
+ pipe.srem(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
|
|
|
+ if pipe.scard(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}").execute()[-1] == 0:
|
|
|
+ pipe.delete(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}") # Remove if empty set
|
|
|
+ pipe.execute()
|
|
|
+
|
|
|
+
|
|
|
+def redis_list_tasks(redis: Redis) -> List[str]:
|
|
|
+ return list(redis.hkeys(REDIS_TASKS_KEY))
|
|
|
+
|
|
|
+
|
|
|
+def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]:
|
|
|
+ return list(redis.smembers(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"))
|
|
|
+
|
|
|
+
|
|
|
+def redis_send_command(redis: Redis, command: dict):
|
|
|
+ redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command))
|
|
|
+
|
|
|
+
|
|
|
def cleanup_task(request, task_id: str, id=None):
|
|
|
"""
|
|
|
Remove a completed or canceled task from the global `tasks` dictionary.
|
|
|
"""
|
|
|
+ if is_redis(request):
|
|
|
+ redis_cleanup_task(request.app.state.redis, task_id, id)
|
|
|
+
|
|
|
tasks.pop(task_id, None) # Remove the task if it exists
|
|
|
|
|
|
# If an ID is provided, remove the task from the chat_tasks dictionary
|
|
@@ -40,6 +112,9 @@ def create_task(request, coroutine, id=None):
|
|
|
else:
|
|
|
chat_tasks[id] = [task_id]
|
|
|
|
|
|
+ if is_redis(request):
|
|
|
+ redis_save_task(request.app.state.redis, task_id, id)
|
|
|
+
|
|
|
return task_id, task
|
|
|
|
|
|
|
|
@@ -47,6 +122,8 @@ def list_tasks(request):
|
|
|
"""
|
|
|
List all currently active task IDs.
|
|
|
"""
|
|
|
+ if is_redis(request):
|
|
|
+ return redis_list_tasks(request.app.state.redis)
|
|
|
return list(tasks.keys())
|
|
|
|
|
|
|
|
@@ -54,6 +131,8 @@ def list_task_ids_by_chat_id(request, id):
|
|
|
"""
|
|
|
List all tasks associated with a specific ID.
|
|
|
"""
|
|
|
+ if is_redis(request):
|
|
|
+ return redis_list_chat_tasks(request.app.state.redis, id)
|
|
|
return chat_tasks.get(id, [])
|
|
|
|
|
|
|
|
@@ -61,6 +140,18 @@ async def stop_task(request, task_id: str):
|
|
|
"""
|
|
|
Cancel a running task and remove it from the global task list.
|
|
|
"""
|
|
|
+ if is_redis(request):
|
|
|
+ # PUBSUB: All instances check if they have this task, and stop if so.
|
|
|
+ redis_send_command(
|
|
|
+ request.app.state.redis,
|
|
|
+ {
|
|
|
+ "action": "stop",
|
|
|
+ "task_id": task_id,
|
|
|
+ },
|
|
|
+ )
|
|
|
+ # Optionally check if task_id still in Redis a few moments later for feedback?
|
|
|
+ return {"status": True, "message": f"Stop signal sent for {task_id}"}
|
|
|
+
|
|
|
task = tasks.get(task_id)
|
|
|
if not task:
|
|
|
raise ValueError(f"Task with ID {task_id} not found.")
|