1
0

prompts.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. from typing import Optional
  2. from fastapi import APIRouter, Depends, HTTPException, status, Request
  3. from open_webui.models.prompts import (
  4. PromptForm,
  5. PromptUserResponse,
  6. PromptModel,
  7. Prompts,
  8. )
  9. from open_webui.constants import ERROR_MESSAGES
  10. from open_webui.utils.auth import get_admin_user, get_verified_user
  11. from open_webui.utils.access_control import has_access, has_permission
  12. from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL
  # Every endpoint in this module registers on this router.
  router = APIRouter()

  ############################
  # GetPrompts
  ############################


  @router.get("/", response_model=list[PromptModel])
  async def get_prompts(user=Depends(get_verified_user)):
      """Return the prompts visible to the caller.

      When the caller is an admin *and* ``BYPASS_ADMIN_ACCESS_CONTROL`` is
      truthy, every prompt from ``Prompts.get_prompts()`` is returned.
      Otherwise the result is ``Prompts.get_prompts_by_user_id(user.id,
      "read")``.
      """
      admin_bypass = user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL
      if admin_bypass:
          return Prompts.get_prompts()
      return Prompts.get_prompts_by_user_id(user.id, "read")
  @router.get("/list", response_model=list[PromptUserResponse])
  async def get_prompt_list(user=Depends(get_verified_user)):
      """Return the prompts the caller may manage, shaped as ``PromptUserResponse``.

      Non-bypassing callers get ``Prompts.get_prompts_by_user_id(user.id,
      "write")``. Admins get every prompt when ``BYPASS_ADMIN_ACCESS_CONTROL``
      is truthy.
      """
      if not (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL):
          return Prompts.get_prompts_by_user_id(user.id, "write")
      return Prompts.get_prompts()
  ############################
  # CreateNewPrompt
  ############################


  @router.post("/create", response_model=Optional[PromptModel])
  async def create_new_prompt(
      request: Request, form_data: PromptForm, user=Depends(get_verified_user)
  ):
      """Create a prompt owned by the caller.

      Non-admins must hold the ``"workspace.prompts"`` permission according to
      ``request.app.state.config.USER_PERMISSIONS``.

      Raises:
          HTTPException: 401 ``UNAUTHORIZED`` when a non-admin lacks the
              permission; 400 ``COMMAND_TAKEN`` when ``form_data.command``
              already exists; 400 ``DEFAULT()`` when the insert returns a
              falsy value.
      """
      is_admin = user.role == "admin"
      if not is_admin and not has_permission(
          user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS
      ):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.UNAUTHORIZED,
          )

      # Commands must be unique; refuse to shadow an existing one.
      if Prompts.get_prompt_by_command(form_data.command) is not None:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.COMMAND_TAKEN,
          )

      created = Prompts.insert_new_prompt(user.id, form_data)
      if not created:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(),
          )
      return created
  ############################
  # GetPromptByCommand
  ############################


  @router.get("/command/{command}", response_model=Optional[PromptModel])
  async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
      """Fetch one prompt by its slash command.

      ``command`` is the path segment without the leading slash; ``/`` is
      prepended before the lookup.

      Access is granted to admins, to the prompt's owner (``prompt.user_id``),
      and to anyone for whom ``has_access(user.id, "read", ...)`` is true.

      Raises:
          HTTPException: 401 with ``NOT_FOUND`` when no prompt matches, or 401
              with ``ACCESS_PROHIBITED`` when the caller is not allowed to read
              the prompt.
      """
      prompt = Prompts.get_prompt_by_command(f"/{command}")
      if not prompt:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      if (
          user.role == "admin"
          or prompt.user_id == user.id
          or has_access(user.id, "read", prompt.access_control)
      ):
          return prompt

      # Reject explicitly: falling off the end would return None, i.e. a 200
      # response with a null body instead of an authorization error. Status
      # and message match the update/delete handlers.
      raise HTTPException(
          status_code=status.HTTP_401_UNAUTHORIZED,
          detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
      )
  ############################
  # UpdatePromptByCommand
  ############################


  @router.post("/command/{command}/update", response_model=Optional[PromptModel])
  async def update_prompt_by_command(
      command: str,
      form_data: PromptForm,
      user=Depends(get_verified_user),
  ):
      """Replace the prompt stored under ``/{command}`` with ``form_data``.

      Editing is allowed for the prompt's creator, for callers with ``"write"``
      access per ``has_access``, and for admins.

      Raises:
          HTTPException: 401 ``NOT_FOUND`` when no prompt matches; 401
              ``ACCESS_PROHIBITED`` when the caller may not edit it, or when
              the update call returns a falsy value.
      """
      slash_command = f"/{command}"
      existing = Prompts.get_prompt_by_command(slash_command)
      if not existing:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      # Creator, write-access grantee, or admin.
      may_edit = (
          existing.user_id == user.id
          or has_access(user.id, "write", existing.access_control)
          or user.role == "admin"
      )
      if not may_edit:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
          )

      updated = Prompts.update_prompt_by_command(slash_command, form_data)
      if not updated:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
          )
      return updated
  ############################
  # DeletePromptByCommand
  ############################


  @router.delete("/command/{command}/delete", response_model=bool)
  async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)):
      """Delete the prompt stored under ``/{command}``.

      Deletion is allowed for the prompt's creator, for callers with
      ``"write"`` access per ``has_access``, and for admins. The return value
      is whatever ``Prompts.delete_prompt_by_command`` returns.

      Raises:
          HTTPException: 401 ``NOT_FOUND`` when no prompt matches; 401
              ``ACCESS_PROHIBITED`` when the caller may not delete it.
      """
      slash_command = f"/{command}"
      target = Prompts.get_prompt_by_command(slash_command)
      if not target:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      # Same authorization rule as updating: creator, write grantee, or admin.
      may_delete = (
          target.user_id == user.id
          or has_access(user.id, "write", target.access_control)
          or user.role == "admin"
      )
      if not may_delete:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
          )

      return Prompts.delete_prompt_by_command(slash_command)