1
0
Эх сурвалжийг харах

refac/fix: multi-replica tasks

Timothy Jaeryang Baek 4 сар өмнө
parent
commit
ea8dc333ee

+ 3 - 3
backend/open_webui/main.py

@@ -513,7 +513,7 @@ async def lifespan(app: FastAPI):
         async_mode=True,
     )
 
-    if isinstance(app.state.redis, Redis):
+    if app.state.redis is not None:
         app.state.redis_task_command_listener = asyncio.create_task(
             redis_task_command_listener(app)
         )
@@ -1424,7 +1424,7 @@ async def stop_task_endpoint(
 
 @app.get("/api/tasks")
 async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)):
-    return {"tasks": list_tasks(request)}
+    return {"tasks": await list_tasks(request)}
 
 
 @app.get("/api/tasks/chat/{chat_id}")
@@ -1435,7 +1435,7 @@ async def list_tasks_by_chat_id_endpoint(
     if chat is None or chat.user_id != user.id:
         return {"task_ids": []}
 
-    task_ids = list_task_ids_by_chat_id(request, chat_id)
+    task_ids = await list_task_ids_by_chat_id(request, chat_id)
 
     print(f"Task IDs for chat {chat_id}: {task_ids}")
     return {"task_ids": task_ids}

+ 26 - 26
backend/open_webui/tasks.py

@@ -3,7 +3,7 @@ import asyncio
 from typing import Dict
 from uuid import uuid4
 import json
-from redis import Redis
+from redis.asyncio import Redis
 from fastapi import Request
 from typing import Dict, List, Optional
 
@@ -19,18 +19,16 @@ REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands"
 
 def is_redis(request: Request) -> bool:
     # Called everywhere a request is available to check Redis
-    return hasattr(request.app.state, "redis") and isinstance(
-        request.app.state.redis, Redis
-    )
+    return hasattr(request.app.state, "redis") and (request.app.state.redis is not None)
 
 
 async def redis_task_command_listener(app):
     redis: Redis = app.state.redis
     pubsub = redis.pubsub()
     await pubsub.subscribe(REDIS_PUBSUB_CHANNEL)
-    print("Subscribed to Redis task command channel")
 
     async for message in pubsub.listen():
+        print(f"Received message: {message}")
         if message["type"] != "message":
             continue
         try:
@@ -49,42 +47,42 @@ async def redis_task_command_listener(app):
 ### ------------------------------
 
 
-def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]):
+async def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]):
     pipe = redis.pipeline()
     pipe.hset(REDIS_TASKS_KEY, task_id, chat_id or "")
     if chat_id:
         pipe.sadd(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
-    pipe.execute()
+    await pipe.execute()
 
 
-def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]):
+async def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]):
     pipe = redis.pipeline()
     pipe.hdel(REDIS_TASKS_KEY, task_id)
     if chat_id:
         pipe.srem(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
-        if pipe.scard(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}").execute()[-1] == 0:
+        if (await pipe.scard(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}").execute())[-1] == 0:
             pipe.delete(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}")  # Remove if empty set
-    pipe.execute()
+    await pipe.execute()
 
 
-def redis_list_tasks(redis: Redis) -> List[str]:
-    return list(redis.hkeys(REDIS_TASKS_KEY))
+async def redis_list_tasks(redis: Redis) -> List[str]:
+    return list(await redis.hkeys(REDIS_TASKS_KEY))
 
 
-def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]:
-    return list(redis.smembers(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"))
+async def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]:
+    return list(await redis.smembers(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"))
 
 
-def redis_send_command(redis: Redis, command: dict):
-    redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command))
+async def redis_send_command(redis: Redis, command: dict):
+    await redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command))
 
 
-def cleanup_task(request, task_id: str, id=None):
+async def cleanup_task(request, task_id: str, id=None):
     """
     Remove a completed or canceled task from the global `tasks` dictionary.
     """
     if is_redis(request):
-        redis_cleanup_task(request.app.state.redis, task_id, id)
+        await redis_cleanup_task(request.app.state.redis, task_id, id)
 
     tasks.pop(task_id, None)  # Remove the task if it exists
 
@@ -95,7 +93,7 @@ def cleanup_task(request, task_id: str, id=None):
             chat_tasks.pop(id, None)
 
 
-def create_task(request, coroutine, id=None):
+async def create_task(request, coroutine, id=None):
     """
     Create a new asyncio task and add it to the global task dictionary.
     """
@@ -103,7 +101,9 @@ def create_task(request, coroutine, id=None):
     task = asyncio.create_task(coroutine)  # Create the task
 
     # Add a done callback for cleanup
-    task.add_done_callback(lambda t: cleanup_task(request, task_id, id))
+    task.add_done_callback(
+        lambda t: asyncio.create_task(cleanup_task(request, task_id, id))
+    )
     tasks[task_id] = task
 
     # If an ID is provided, associate the task with that ID
@@ -113,26 +113,26 @@ def create_task(request, coroutine, id=None):
         chat_tasks[id] = [task_id]
 
     if is_redis(request):
-        redis_save_task(request.app.state.redis, task_id, id)
+        await redis_save_task(request.app.state.redis, task_id, id)
 
     return task_id, task
 
 
-def list_tasks(request):
+async def list_tasks(request):
     """
     List all currently active task IDs.
     """
     if is_redis(request):
-        return redis_list_tasks(request.app.state.redis)
+        return await redis_list_tasks(request.app.state.redis)
     return list(tasks.keys())
 
 
-def list_task_ids_by_chat_id(request, id):
+async def list_task_ids_by_chat_id(request, id):
     """
     List all tasks associated with a specific ID.
     """
     if is_redis(request):
-        return redis_list_chat_tasks(request.app.state.redis, id)
+        return await redis_list_chat_tasks(request.app.state.redis, id)
     return chat_tasks.get(id, [])
 
 
@@ -142,7 +142,7 @@ async def stop_task(request, task_id: str):
     """
     if is_redis(request):
         # PUBSUB: All instances check if they have this task, and stop if so.
-        redis_send_command(
+        await redis_send_command(
             request.app.state.redis,
             {
                 "action": "stop",

+ 1 - 1
backend/open_webui/utils/middleware.py

@@ -2413,7 +2413,7 @@ async def process_chat_response(
                 await response.background()
 
         # background_tasks.add_task(post_response_handler, response, events)
-        task_id, _ = create_task(
+        task_id, _ = await create_task(
             request, post_response_handler(response, events), id=metadata["chat_id"]
         )
         return {"status": True, "task_id": task_id}