tasks.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # tasks.py
  2. import asyncio
  3. from typing import Dict
  4. from uuid import uuid4
  5. import json
  6. from redis import Redis
  7. from fastapi import Request
  8. from typing import Dict, List, Optional
  9. # A dictionary to keep track of active tasks
  10. tasks: Dict[str, asyncio.Task] = {}
  11. chat_tasks = {}
  12. REDIS_TASKS_KEY = "open-webui:tasks"
  13. REDIS_CHAT_TASKS_KEY = "open-webui:tasks:chat"
  14. REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands"
  15. def is_redis(request: Request) -> bool:
  16. # Called everywhere a request is available to check Redis
  17. return hasattr(request.app.state, "redis") and isinstance(
  18. request.app.state.redis, Redis
  19. )
  20. async def redis_task_command_listener(app):
  21. redis: Redis = app.state.redis
  22. pubsub = redis.pubsub()
  23. await pubsub.subscribe(REDIS_PUBSUB_CHANNEL)
  24. print("Subscribed to Redis task command channel")
  25. async for message in pubsub.listen():
  26. if message["type"] != "message":
  27. continue
  28. try:
  29. command = json.loads(message["data"])
  30. if command.get("action") == "stop":
  31. task_id = command.get("task_id")
  32. local_task = tasks.get(task_id)
  33. if local_task:
  34. local_task.cancel()
  35. except Exception as e:
  36. print(f"Error handling distributed task command: {e}")
  37. ### ------------------------------
  38. ### REDIS-ENABLED HANDLERS
  39. ### ------------------------------
  40. def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]):
  41. pipe = redis.pipeline()
  42. pipe.hset(REDIS_TASKS_KEY, task_id, chat_id or "")
  43. if chat_id:
  44. pipe.sadd(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
  45. pipe.execute()
  46. def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]):
  47. pipe = redis.pipeline()
  48. pipe.hdel(REDIS_TASKS_KEY, task_id)
  49. if chat_id:
  50. pipe.srem(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
  51. if pipe.scard(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}").execute()[-1] == 0:
  52. pipe.delete(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}") # Remove if empty set
  53. pipe.execute()
  54. def redis_list_tasks(redis: Redis) -> List[str]:
  55. return list(redis.hkeys(REDIS_TASKS_KEY))
  56. def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]:
  57. return list(redis.smembers(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"))
  58. def redis_send_command(redis: Redis, command: dict):
  59. redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command))
  60. def cleanup_task(request, task_id: str, id=None):
  61. """
  62. Remove a completed or canceled task from the global `tasks` dictionary.
  63. """
  64. if is_redis(request):
  65. redis_cleanup_task(request.app.state.redis, task_id, id)
  66. tasks.pop(task_id, None) # Remove the task if it exists
  67. # If an ID is provided, remove the task from the chat_tasks dictionary
  68. if id and task_id in chat_tasks.get(id, []):
  69. chat_tasks[id].remove(task_id)
  70. if not chat_tasks[id]: # If no tasks left for this ID, remove the entry
  71. chat_tasks.pop(id, None)
  72. def create_task(request, coroutine, id=None):
  73. """
  74. Create a new asyncio task and add it to the global task dictionary.
  75. """
  76. task_id = str(uuid4()) # Generate a unique ID for the task
  77. task = asyncio.create_task(coroutine) # Create the task
  78. # Add a done callback for cleanup
  79. task.add_done_callback(lambda t: cleanup_task(request, task_id, id))
  80. tasks[task_id] = task
  81. # If an ID is provided, associate the task with that ID
  82. if chat_tasks.get(id):
  83. chat_tasks[id].append(task_id)
  84. else:
  85. chat_tasks[id] = [task_id]
  86. if is_redis(request):
  87. redis_save_task(request.app.state.redis, task_id, id)
  88. return task_id, task
  89. def list_tasks(request):
  90. """
  91. List all currently active task IDs.
  92. """
  93. if is_redis(request):
  94. return redis_list_tasks(request.app.state.redis)
  95. return list(tasks.keys())
  96. def list_task_ids_by_chat_id(request, id):
  97. """
  98. List all tasks associated with a specific ID.
  99. """
  100. if is_redis(request):
  101. return redis_list_chat_tasks(request.app.state.redis, id)
  102. return chat_tasks.get(id, [])
  103. async def stop_task(request, task_id: str):
  104. """
  105. Cancel a running task and remove it from the global task list.
  106. """
  107. if is_redis(request):
  108. # PUBSUB: All instances check if they have this task, and stop if so.
  109. redis_send_command(
  110. request.app.state.redis,
  111. {
  112. "action": "stop",
  113. "task_id": task_id,
  114. },
  115. )
  116. # Optionally check if task_id still in Redis a few moments later for feedback?
  117. return {"status": True, "message": f"Stop signal sent for {task_id}"}
  118. task = tasks.get(task_id)
  119. if not task:
  120. raise ValueError(f"Task with ID {task_id} not found.")
  121. task.cancel() # Request task cancellation
  122. try:
  123. await task # Wait for the task to handle the cancellation
  124. except asyncio.CancelledError:
  125. # Task successfully canceled
  126. tasks.pop(task_id, None) # Remove it from the dictionary
  127. return {"status": True, "message": f"Task {task_id} successfully stopped."}
  128. return {"status": False, "message": f"Failed to stop task {task_id}."}