1
0

tasks.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. # tasks.py
  2. import asyncio
  3. from typing import Dict
  4. from uuid import uuid4
  5. import json
  6. import logging
  7. from redis.asyncio import Redis
  8. from fastapi import Request
  9. from typing import Dict, List, Optional
  10. from open_webui.env import SRC_LOG_LEVELS, REDIS_KEY_PREFIX
  11. log = logging.getLogger(__name__)
  12. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  13. # A dictionary to keep track of active tasks
  14. tasks: Dict[str, asyncio.Task] = {}
  15. item_tasks = {}
  16. REDIS_TASKS_KEY = f"{REDIS_KEY_PREFIX}:tasks"
  17. REDIS_ITEM_TASKS_KEY = f"{REDIS_KEY_PREFIX}:tasks:item"
  18. REDIS_PUBSUB_CHANNEL = f"{REDIS_KEY_PREFIX}:tasks:commands"
  19. async def redis_task_command_listener(app):
  20. redis: Redis = app.state.redis
  21. pubsub = redis.pubsub()
  22. await pubsub.subscribe(REDIS_PUBSUB_CHANNEL)
  23. async for message in pubsub.listen():
  24. if message["type"] != "message":
  25. continue
  26. try:
  27. command = json.loads(message["data"])
  28. if command.get("action") == "stop":
  29. task_id = command.get("task_id")
  30. local_task = tasks.get(task_id)
  31. if local_task:
  32. local_task.cancel()
  33. except Exception as e:
  34. log.exception(f"Error handling distributed task command: {e}")
  35. ### ------------------------------
  36. ### REDIS-ENABLED HANDLERS
  37. ### ------------------------------
  38. async def redis_save_task(redis: Redis, task_id: str, item_id: Optional[str]):
  39. pipe = redis.pipeline()
  40. pipe.hset(REDIS_TASKS_KEY, task_id, item_id or "")
  41. if item_id:
  42. pipe.sadd(f"{REDIS_ITEM_TASKS_KEY}:{item_id}", task_id)
  43. await pipe.execute()
  44. async def redis_cleanup_task(redis: Redis, task_id: str, item_id: Optional[str]):
  45. pipe = redis.pipeline()
  46. pipe.hdel(REDIS_TASKS_KEY, task_id)
  47. if item_id:
  48. pipe.srem(f"{REDIS_ITEM_TASKS_KEY}:{item_id}", task_id)
  49. if (await pipe.scard(f"{REDIS_ITEM_TASKS_KEY}:{item_id}").execute())[-1] == 0:
  50. pipe.delete(f"{REDIS_ITEM_TASKS_KEY}:{item_id}") # Remove if empty set
  51. await pipe.execute()
  52. async def redis_list_tasks(redis: Redis) -> List[str]:
  53. return list(await redis.hkeys(REDIS_TASKS_KEY))
  54. async def redis_list_item_tasks(redis: Redis, item_id: str) -> List[str]:
  55. return list(await redis.smembers(f"{REDIS_ITEM_TASKS_KEY}:{item_id}"))
  56. async def redis_send_command(redis: Redis, command: dict):
  57. await redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command))
  58. async def cleanup_task(redis, task_id: str, id=None):
  59. """
  60. Remove a completed or canceled task from the global `tasks` dictionary.
  61. """
  62. if redis:
  63. await redis_cleanup_task(redis, task_id, id)
  64. tasks.pop(task_id, None) # Remove the task if it exists
  65. # If an ID is provided, remove the task from the item_tasks dictionary
  66. if id and task_id in item_tasks.get(id, []):
  67. item_tasks[id].remove(task_id)
  68. if not item_tasks[id]: # If no tasks left for this ID, remove the entry
  69. item_tasks.pop(id, None)
  70. async def create_task(redis, coroutine, id=None):
  71. """
  72. Create a new asyncio task and add it to the global task dictionary.
  73. """
  74. task_id = str(uuid4()) # Generate a unique ID for the task
  75. task = asyncio.create_task(coroutine) # Create the task
  76. # Add a done callback for cleanup
  77. task.add_done_callback(
  78. lambda t: asyncio.create_task(cleanup_task(redis, task_id, id))
  79. )
  80. tasks[task_id] = task
  81. # If an ID is provided, associate the task with that ID
  82. if item_tasks.get(id):
  83. item_tasks[id].append(task_id)
  84. else:
  85. item_tasks[id] = [task_id]
  86. if redis:
  87. await redis_save_task(redis, task_id, id)
  88. return task_id, task
  89. async def list_tasks(redis):
  90. """
  91. List all currently active task IDs.
  92. """
  93. if redis:
  94. return await redis_list_tasks(redis)
  95. return list(tasks.keys())
  96. async def list_task_ids_by_item_id(redis, id):
  97. """
  98. List all tasks associated with a specific ID.
  99. """
  100. if redis:
  101. return await redis_list_item_tasks(redis, id)
  102. return item_tasks.get(id, [])
  103. async def stop_task(redis, task_id: str):
  104. """
  105. Cancel a running task and remove it from the global task list.
  106. """
  107. if redis:
  108. # PUBSUB: All instances check if they have this task, and stop if so.
  109. await redis_send_command(
  110. redis,
  111. {
  112. "action": "stop",
  113. "task_id": task_id,
  114. },
  115. )
  116. # Optionally check if task_id still in Redis a few moments later for feedback?
  117. return {"status": True, "message": f"Stop signal sent for {task_id}"}
  118. task = tasks.pop(task_id, None)
  119. if not task:
  120. return {"status": False, "message": f"Task with ID {task_id} not found."}
  121. task.cancel() # Request task cancellation
  122. try:
  123. await task # Wait for the task to handle the cancellation
  124. except asyncio.CancelledError:
  125. # Task successfully canceled
  126. return {"status": True, "message": f"Task {task_id} successfully stopped."}
  127. return {"status": False, "message": f"Failed to stop task {task_id}."}
  128. async def stop_item_tasks(redis: Redis, item_id: str):
  129. """
  130. Stop all tasks associated with a specific item ID.
  131. """
  132. task_ids = await list_task_ids_by_item_id(redis, item_id)
  133. if not task_ids:
  134. return {"status": True, "message": f"No tasks found for item {item_id}."}
  135. for task_id in task_ids:
  136. result = await stop_task(redis, task_id)
  137. if not result["status"]:
  138. return result # Return the first failure
  139. return {"status": True, "message": f"All tasks for item {item_id} stopped."}