Browse Source

refac/fix: multi-replica stop task (response)

Timothy Jaeryang Baek 4 months ago
parent
commit
d8d8380a78
2 changed files with 102 additions and 0 deletions
  1. 11 0
      backend/open_webui/main.py
  2. 91 0
      backend/open_webui/tasks.py

+ 11 - 0
backend/open_webui/main.py

@@ -10,6 +10,7 @@ import time
 import random
 from uuid import uuid4
 
+
 from contextlib import asynccontextmanager
 from urllib.parse import urlencode, parse_qs, urlparse
 from pydantic import BaseModel
@@ -20,6 +21,7 @@ from aiocache import cached
 import aiohttp
 import anyio.to_thread
 import requests
+from redis import Redis
 
 
 from fastapi import (
@@ -436,6 +438,7 @@ from open_webui.utils.security_headers import SecurityHeadersMiddleware
 from open_webui.utils.redis import get_redis_connection
 
 from open_webui.tasks import (
+    redis_task_command_listener,
     list_task_ids_by_chat_id,
     stop_task,
     list_tasks,
@@ -508,6 +511,11 @@ async def lifespan(app: FastAPI):
         ),
     )
 
+    if isinstance(app.state.redis, Redis):
+        app.state.redis_task_command_listener = asyncio.create_task(
+            redis_task_command_listener(app)
+        )
+
     if THREAD_POOL_SIZE and THREAD_POOL_SIZE > 0:
         limiter = anyio.to_thread.current_default_thread_limiter()
         limiter.total_tokens = THREAD_POOL_SIZE
@@ -516,6 +524,9 @@ async def lifespan(app: FastAPI):
 
     yield
 
+    if hasattr(app.state, "redis_task_command_listener"):
+        app.state.redis_task_command_listener.cancel()
+
 
 app = FastAPI(
     title="Open WebUI",

+ 91 - 0
backend/open_webui/tasks.py

@@ -4,16 +4,88 @@ from typing import Dict
 from uuid import uuid4
 import json
 from redis import Redis
+from fastapi import Request
+from typing import Dict, List, Optional
 
 # A dictionary to keep track of active tasks
 tasks: Dict[str, asyncio.Task] = {}
 chat_tasks = {}
 
 
+REDIS_TASKS_KEY = "open-webui:tasks"
+REDIS_CHAT_TASKS_KEY = "open-webui:tasks:chat"
+REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands"
+
+
+def is_redis(request: Request) -> bool:
+    """
+    Report whether this app instance has a Redis client attached.
+
+    Returns True only when `request.app.state` has a `redis` attribute and
+    that attribute is an instance of `Redis`; otherwise False.
+    """
+    # A missing attribute falls back to None, which never passes isinstance.
+    client = getattr(request.app.state, "redis", None)
+    return isinstance(client, Redis)
+
+
+async def redis_task_command_listener(app):
+    """
+    Listen on the Redis command channel and cancel matching local tasks.
+
+    Subscribes to `REDIS_PUBSUB_CHANNEL` and, for every "stop" command
+    received, cancels the task registered under that ID in this process's
+    `tasks` dictionary. Commands naming a task ID that is not in `tasks`
+    are ignored. Any error while handling a single command is printed and
+    the loop keeps running.
+
+    NOTE(review): the pubsub calls below are awaited, so `app.state.redis`
+    has to be an asyncio-capable client - confirm against how it is created.
+    """
+    client: Redis = app.state.redis
+    channel = client.pubsub()
+    await channel.subscribe(REDIS_PUBSUB_CHANNEL)
+    print("Subscribed to Redis task command channel")
+
+    async for event in channel.listen():
+        # Anything that is not a "message" event is skipped.
+        if event["type"] == "message":
+            try:
+                payload = json.loads(event["data"])
+                if payload.get("action") != "stop":
+                    continue
+                running = tasks.get(payload.get("task_id"))
+                if running:
+                    running.cancel()
+            except Exception as e:
+                print(f"Error handling distributed task command: {e}")
+
+
+### ------------------------------
+### REDIS-ENABLED HANDLERS
+### ------------------------------
+
+
+def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]):
+    """
+    Record a task in Redis.
+
+    Stores `task_id` in the `REDIS_TASKS_KEY` hash with the chat ID (or an
+    empty string when there is none) as its value and, when a chat ID is
+    given, adds the task to that chat's set. Both writes are queued on one
+    pipeline and sent together.
+    """
+    batch = redis.pipeline()
+    batch.hset(REDIS_TASKS_KEY, task_id, chat_id if chat_id else "")
+    if chat_id:
+        chat_key = f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"
+        batch.sadd(chat_key, task_id)
+    batch.execute()
+
+
+def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]):
+    """
+    Remove a task's bookkeeping from Redis.
+
+    Deletes `task_id` from the `REDIS_TASKS_KEY` hash and, when a chat ID is
+    given, from that chat's task set. Everything is sent in one pipeline
+    execution.
+    """
+    pipe = redis.pipeline()
+    pipe.hdel(REDIS_TASKS_KEY, task_id)
+    if chat_id:
+        # No SCARD/DEL follow-up: Redis drops a set key by itself once its
+        # last member is removed, and a separate check-then-delete could
+        # erase a task another replica added in between.
+        pipe.srem(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
+    pipe.execute()
+
+
+def redis_list_tasks(redis: Redis) -> List[str]:
+    """
+    Return the task IDs stored as fields of the `REDIS_TASKS_KEY` hash.
+    """
+    task_ids = redis.hkeys(REDIS_TASKS_KEY)
+    return list(task_ids)
+
+
+def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]:
+    """
+    Return the task IDs held in the Redis set for the given chat ID.
+    """
+    chat_key = f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"
+    members = redis.smembers(chat_key)
+    return list(members)
+
+
+def redis_send_command(redis: Redis, command: dict):
+    """
+    Publish a command on `REDIS_PUBSUB_CHANNEL`.
+
+    The command dictionary is serialized to JSON before it is published.
+    """
+    payload = json.dumps(command)
+    redis.publish(REDIS_PUBSUB_CHANNEL, payload)
+
+
 def cleanup_task(request, task_id: str, id=None):
     """
     Remove a completed or canceled task from the global `tasks` dictionary.
     """
+    if is_redis(request):
+        redis_cleanup_task(request.app.state.redis, task_id, id)
+
     tasks.pop(task_id, None)  # Remove the task if it exists
 
     # If an ID is provided, remove the task from the chat_tasks dictionary
@@ -40,6 +112,9 @@ def create_task(request, coroutine, id=None):
     else:
         chat_tasks[id] = [task_id]
 
+    if is_redis(request):
+        redis_save_task(request.app.state.redis, task_id, id)
+
     return task_id, task
 
 
@@ -47,6 +122,8 @@ def list_tasks(request):
     """
     List all currently active task IDs.
     """
+    if is_redis(request):
+        return redis_list_tasks(request.app.state.redis)
     return list(tasks.keys())
 
 
@@ -54,6 +131,8 @@ def list_task_ids_by_chat_id(request, id):
     """
     List all tasks associated with a specific ID.
     """
+    if is_redis(request):
+        return redis_list_chat_tasks(request.app.state.redis, id)
     return chat_tasks.get(id, [])
 
 
@@ -61,6 +140,18 @@ async def stop_task(request, task_id: str):
     """
     Cancel a running task and remove it from the global task list.
     """
+    if is_redis(request):
+        # PUBSUB: All instances check if they have this task, and stop if so.
+        redis_send_command(
+            request.app.state.redis,
+            {
+                "action": "stop",
+                "task_id": task_id,
+            },
+        )
+        # Optionally check if task_id still in Redis a few moments later for feedback?
+        return {"status": True, "message": f"Stop signal sent for {task_id}"}
+
     task = tasks.get(task_id)
     if not task:
         raise ValueError(f"Task with ID {task_id} not found.")