Timothy Jaeryang Baek 2 maanden geleden
bovenliggende
commit
788e7d0487
2 gewijzigde bestanden met toevoegingen van 19 en 24 verwijderingen
  1. 3 3
      backend/open_webui/main.py
  2. 16 21
      backend/open_webui/tasks.py

+ 3 - 3
backend/open_webui/main.py

@@ -1486,7 +1486,7 @@ async def stop_task_endpoint(
     request: Request, task_id: str, user=Depends(get_verified_user)
 ):
     try:
-        result = await stop_task(request, task_id)
+        result = await stop_task(request.app.state.redis, task_id)
         return result
     except ValueError as e:
         raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
@@ -1494,7 +1494,7 @@ async def stop_task_endpoint(
 
 @app.get("/api/tasks")
 async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)):
-    return {"tasks": await list_tasks(request)}
+    return {"tasks": await list_tasks(request.app.state.redis)}
 
 
 @app.get("/api/tasks/chat/{chat_id}")
@@ -1505,7 +1505,7 @@ async def list_tasks_by_chat_id_endpoint(
     if chat is None or chat.user_id != user.id:
         return {"task_ids": []}
 
-    task_ids = await list_task_ids_by_chat_id(request, chat_id)
+    task_ids = await list_task_ids_by_chat_id(request.app.state.redis, chat_id)
 
     log.debug(f"Task IDs for chat {chat_id}: {task_ids}")
     return {"task_ids": task_ids}

+ 16 - 21
backend/open_webui/tasks.py

@@ -24,11 +24,6 @@ REDIS_CHAT_TASKS_KEY = "open-webui:tasks:chat"
 REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands"
 
 
-def is_redis(request: Request) -> bool:
-    # Called everywhere a request is available to check Redis
-    return hasattr(request.app.state, "redis") and (request.app.state.redis is not None)
-
-
 async def redis_task_command_listener(app):
     redis: Redis = app.state.redis
     pubsub = redis.pubsub()
@@ -83,12 +78,12 @@ async def redis_send_command(redis: Redis, command: dict):
     await redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command))
 
 
-async def cleanup_task(request, task_id: str, id=None):
+async def cleanup_task(redis, task_id: str, id=None):
     """
     Remove a completed or canceled task from the global `tasks` dictionary.
     """
-    if is_redis(request):
-        await redis_cleanup_task(request.app.state.redis, task_id, id)
+    if redis:
+        await redis_cleanup_task(redis, task_id, id)
 
     tasks.pop(task_id, None)  # Remove the task if it exists
 
@@ -99,7 +94,7 @@ async def cleanup_task(request, task_id: str, id=None):
             chat_tasks.pop(id, None)
 
 
-async def create_task(request, coroutine, id=None):
+async def create_task(redis, coroutine, id=None):
     """
     Create a new asyncio task and add it to the global task dictionary.
     """
@@ -108,7 +103,7 @@ async def create_task(request, coroutine, id=None):
 
     # Add a done callback for cleanup
     task.add_done_callback(
-        lambda t: asyncio.create_task(cleanup_task(request, task_id, id))
+        lambda t: asyncio.create_task(cleanup_task(redis, task_id, id))
     )
     tasks[task_id] = task
 
@@ -118,38 +113,38 @@ async def create_task(request, coroutine, id=None):
     else:
         chat_tasks[id] = [task_id]
 
-    if is_redis(request):
-        await redis_save_task(request.app.state.redis, task_id, id)
+    if redis:
+        await redis_save_task(redis, task_id, id)
 
     return task_id, task
 
 
-async def list_tasks(request):
+async def list_tasks(redis):
     """
     List all currently active task IDs.
     """
-    if is_redis(request):
-        return await redis_list_tasks(request.app.state.redis)
+    if redis:
+        return await redis_list_tasks(redis)
     return list(tasks.keys())
 
 
-async def list_task_ids_by_chat_id(request, id):
+async def list_task_ids_by_chat_id(redis, id):
     """
     List all tasks associated with a specific ID.
     """
-    if is_redis(request):
-        return await redis_list_chat_tasks(request.app.state.redis, id)
+    if redis:
+        return await redis_list_chat_tasks(redis, id)
     return chat_tasks.get(id, [])
 
 
-async def stop_task(request, task_id: str):
+async def stop_task(redis, task_id: str):
     """
     Cancel a running task and remove it from the global task list.
     """
-    if is_redis(request):
+    if redis:
         # PUBSUB: All instances check if they have this task, and stop if so.
         await redis_send_command(
-            request.app.state.redis,
+            redis,
             {
                 "action": "stop",
                 "task_id": task_id,